Big Data Analytics using Apache Spark

CN6022 - Big Data Infrastructure and Manipulation

Authors

Jayrup Nakawala (u2613621)

Yogi Patel (u2536809)

Jasmi Alasapuri (u2571395)

1 Introduction

1.1 Project Overview

This project demonstrates advanced Big Data manipulation and analytics using Apache Spark SQL. While the module provided a baseline dataset (Air Flight Status), we elected to use the ESA Gaia Data Release 3 (DR3) for this analysis. Gaia DR3 is among the largest and most detailed astronomical catalogs published to date, containing astrometric and photometric data for over 1.8 billion sources.

1.2 Dataset Selection and Justification

The core requirement for this coursework was to utilize a dataset that exceeds the volume and complexity of the provided sample. The Gaia DR3 dataset fits this criterion perfectly for three reasons:

  1. Volume: Even a small random subset of Gaia (the first 3 million entries by random_index, about 0.17% of the catalog, of which 844,868 rows pass our download filters) significantly exceeds the size of the standard flight dataset, requiring distributed computing techniques to process efficiently.
  2. Complexity: Unlike the flat structure of flight logs, astronomical data requires complex feature engineering (e.g., calculating Absolute Magnitude from Parallax).
  3. Scientific Relevance: This dataset allows for genuine astrophysical discovery, including the identification of White Dwarfs and Binary Systems.

1.3 Data Acquisition Strategy: The “Two-Tier” Architecture

Ingesting the entire 1.8 billion row catalog is computationally infeasible for this project’s scope. Furthermore, a simple random sample introduces Malmquist Bias, where bright, distant stars drown out faint, local objects. To resolve this, we architected a Two-Tier Data Strategy, acquiring two distinct datasets via the ESA Archive (ADQL):

  • Dataset A: The “Galactic Survey” (Macro-Analysis): A random sample of the entire sky (random_index < 3,000,000, about 0.17% of the catalog). This “Deep Field” dataset is used to map the broad structure of the Milky Way and analyze general stellar demographics. It favours breadth over precision.
  • Dataset B: The “Local Bubble” (Micro-Physics): A volume-limited sample of all stars within 100 parsecs (\(d \le 100\,\text{pc}\), i.e. parallax of at least 10 mas). This high-precision dataset eliminates distance-related noise, allowing us to detect faint objects like White Dwarfs that would otherwise be invisible. It favours precision over breadth.
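
The 100 pc volume limit of Dataset B maps directly onto a parallax cut. The following is a minimal sketch, assuming the simple inverse-parallax distance estimate with parallax in milliarcseconds; the helper names are hypothetical and not part of any Gaia tooling.

```python
def distance_pc(parallax_mas: float) -> float:
    """Naive distance estimate in parsecs from a parallax in milliarcseconds."""
    if parallax_mas <= 0:
        raise ValueError("inverse-parallax distance needs a positive parallax")
    return 1000.0 / parallax_mas


def min_parallax_mas(max_distance_pc: float) -> float:
    """Smallest parallax (mas) that still lies inside the given distance limit."""
    if max_distance_pc <= 0:
        raise ValueError("distance limit must be positive")
    return 1000.0 / max_distance_pc


# Parallax threshold that corresponds to a 100 pc volume limit
print(min_parallax_mas(100.0))
# Distance of a star sitting exactly on that threshold
print(distance_pc(10.0))
```

This is the reason the Local Bubble download query filters on parallax >= 10.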

1.3.1 Data Schema and Key Columns

The following columns from the Gaia dataset will be used in our analysis:

  • source_id: Unique identifier for each star. (64-bit Integer)
  • ra: Right Ascension (celestial longitude). (Double Precision)
  • dec: Declination (celestial latitude). (Double Precision)
  • parallax: Parallax in milliarcseconds, used to calculate distance in parsecs (\(d = 1000/p\) for \(p\) in mas). (Double Precision)
  • parallax_error: The uncertainty in the parallax measurement. (Single Precision)
  • pmra: Proper motion in the direction of Right Ascension. (Double Precision)
  • pmdec: Proper motion in the direction of Declination. (Double Precision)
  • phot_g_mean_mag: Mean apparent magnitude in the G-band (a measure of brightness as seen from Earth). (Single Precision)
  • bp_rp: The blue-red color index, a proxy for the star’s surface temperature. (Single Precision)
  • teff_gspphot: Effective temperature of the star’s photosphere, derived from photometry. (Single Precision)
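
To make the effect of these type choices concrete, the sketch below estimates the raw in-memory size of one row under the schema listed above. It is a back-of-the-envelope calculation only: it ignores null masks, Parquet compression and Spark's own overhead, and the `SCHEMA` and `row_bytes` names are introduced here purely for illustration.

```python
# Bytes per value for the storage types used in the schema above.
TYPE_BYTES = {"int64": 8, "float64": 8, "float32": 4}

# Column -> storage type, as listed in the schema.
SCHEMA = {
    "source_id": "int64",
    "ra": "float64",
    "dec": "float64",
    "parallax": "float64",
    "parallax_error": "float32",
    "pmra": "float64",
    "pmdec": "float64",
    "phot_g_mean_mag": "float32",
    "bp_rp": "float32",
    "teff_gspphot": "float32",
}


def row_bytes(schema: dict) -> int:
    """Raw payload size of one row, in bytes, for the given column types."""
    return sum(TYPE_BYTES[dtype] for dtype in schema.values())


# Same columns, but with every float stored as a 64-bit double.
ALL_WIDE = {name: ("int64" if t == "int64" else "float64") for name, t in SCHEMA.items()}

mixed = row_bytes(SCHEMA)
wide = row_bytes(ALL_WIDE)
print(f"mixed precision: {mixed} bytes/row, all doubles: {wide} bytes/row")
print(f"survey payload at 844,868 rows: {mixed * 844_868 / 1e6:.1f} MB")
```

Under these assumptions, the four single-precision columns each take half the space they would as doubles, while the coordinates and proper motions keep full precision.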

[Figure: Stellar definitions]

1.4 Team Structure and Objectives

The analysis is divided into three distinct workstreams, each focusing on a different aspect of the data:

  • Jasmi (Stellar Demographics): Focuses on classifying star populations (H-R Diagram) and identifying high-velocity outliers using the Galactic Survey.
  • Yogi (Galactic Structure): Maps the Milky Way’s density in detail and analyses measurement error rates across the sky.
  • Jayrup (Exotic Star Hunting): Utilizes the high-precision “Local Bubble” and “Galactic Survey” data to detect rare stellar remnants and gravitationally bound binary star systems.

1.5 Understanding the Data

1.5.1 Installing dependencies

Code
%pip install pyspark astroquery pandas pyarrow seaborn matplotlib --quiet

1.5.2 Downloading the Datasets

Code
from astroquery.gaia import Gaia
import os

output_dir = "../data"
os.makedirs(output_dir, exist_ok=True)

def save_strict_parquet(results, filename):
    """
    Converts Astropy Table to Pandas with strict Gaia Data Model types.
    """
    df = results.to_pandas()
    
    # 1. Enforce Source_ID as 64-bit Integer (Long)
    df['source_id'] = df['source_id'].astype('int64')

    # 2. Enforce Double Precision (float64) for Angles/Velocity
    doubles = ['ra', 'dec', 'parallax', 'pmra', 'pmdec']
    for col in doubles:
        if col in df.columns:
            df[col] = df[col].astype('float64')

    # 3. Enforce Single Precision (float32) for Errors/Magnitudes
    # This saves 50% RAM on these columns vs standard floats.
    floats = ['parallax_error', 'bp_rp', 'phot_g_mean_mag','teff_gspphot']
    for col in floats:
        if col in df.columns:
            df[col] = df[col].astype('float32')
            
    # Save
    print(f">> Saving {len(df)} rows to {filename}...")
    df.to_parquet(filename, index=False)


# --- JOB 1: SURVEY ---
survey_file = os.path.join(output_dir, "gaia_survey.parquet")
if not os.path.exists(survey_file):
    print(">> Downloading Survey...")
    q = """
    SELECT source_id, ra, dec, parallax, parallax_error, pmra, pmdec, 
           phot_g_mean_mag, bp_rp, teff_gspphot
    FROM gaiadr3.gaia_source
    WHERE parallax > 0 AND phot_g_mean_mag < 19 AND random_index < 3000000
    """
    job = Gaia.launch_job_async(q)
    save_strict_parquet(job.get_results(), survey_file)
else:
    print(">> Survey already downloaded.")

# --- JOB 2: LOCAL BUBBLE ---
local_file = os.path.join(output_dir, "gaia_100pc.parquet")
if not os.path.exists(local_file):
    print(">> Downloading Local Bubble...")
    q = """
    SELECT source_id, ra, dec, parallax, parallax_error, pmra, pmdec, 
           phot_g_mean_mag, bp_rp, teff_gspphot
    FROM gaiadr3.gaia_source
    WHERE parallax >= 10 AND parallax_over_error > 5
    """
    job = Gaia.launch_job_async(q)
    save_strict_parquet(job.get_results(), local_file)
else:
    print(">> Local Bubble already downloaded.")

print(">> Done.")
Workaround solutions for the Gaia Archive issues following the infrastructure upgrade: https://www.cosmos.esa.int/web/gaia/news#WorkaroundArchive
>> Survey already downloaded.
>> Local Bubble already downloaded.
>> Done.

1.5.3 Exploring the Datasets

Code
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

# Initialize Spark
spark = SparkSession.builder \
    .appName("Gaia_Data_Exploration") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

# Load Datasets
survey_path = "../data/gaia_survey.parquet"
local_path = "../data/gaia_100pc.parquet"

df_survey = spark.read.parquet(survey_path)
df_local = spark.read.parquet(local_path)

print(">>> DATASET 1: GALACTIC SURVEY (Macro)")
print(f"Total Rows: {df_survey.count():,}")
df_survey.printSchema()

print("\n>>> DATASET 2: LOCAL BUBBLE (Micro)")
print(f"Total Rows: {df_local.count():,}")
df_local.printSchema()

# ====================================================
# 1. PHYSICAL COMPARISON
# ====================================================

print("\n>>> STATISTICAL COMPARISON: PARALLAX (Distance)")
print("Note: Distance (pc) is approx 1000 / parallax.")

print("-- Survey Dataset Stats --")
df_survey.select("parallax", "phot_g_mean_mag", "pmra").describe().show()

print("-- Local Bubble Stats --")
df_local.select("parallax", "phot_g_mean_mag", "pmra").describe().show()

# ====================================================
# 2. QUALITY CHECK (Null Analysis)
# ====================================================

from pyspark.sql.functions import when

def count_nulls(df, columns):
    """Return a one-row DataFrame with the number of NULLs in each given column."""
    return df.select([count(when(col(c).isNull(), c)).alias(c) for c in columns])

# We only check critical columns
cols_to_check = ["parallax", "pmra", "teff_gspphot", "bp_rp"]

print("\n>>> NULL VALUE ANALYSIS (Survey Dataset)")
count_nulls(df_survey, cols_to_check).show()

print("\n>>> NULL VALUE ANALYSIS (Local Dataset)")
count_nulls(df_local, cols_to_check).show()
WARNING: Using incubator modules: jdk.incubator.vector
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
>>> DATASET 1: GALACTIC SURVEY (Macro)
Total Rows: 844,868
root
 |-- source_id: long (nullable = true)
 |-- ra: double (nullable = true)
 |-- dec: double (nullable = true)
 |-- parallax: double (nullable = true)
 |-- parallax_error: float (nullable = true)
 |-- pmra: double (nullable = true)
 |-- pmdec: double (nullable = true)
 |-- phot_g_mean_mag: float (nullable = true)
 |-- bp_rp: float (nullable = true)
 |-- teff_gspphot: float (nullable = true)


>>> DATASET 2: LOCAL BUBBLE (Micro)
Total Rows: 541,958
root
 |-- source_id: long (nullable = true)
 |-- ra: double (nullable = true)
 |-- dec: double (nullable = true)
 |-- parallax: double (nullable = true)
 |-- parallax_error: float (nullable = true)
 |-- pmra: double (nullable = true)
 |-- pmdec: double (nullable = true)
 |-- phot_g_mean_mag: float (nullable = true)
 |-- bp_rp: float (nullable = true)
 |-- teff_gspphot: float (nullable = true)


>>> STATISTICAL COMPARISON: PARALLAX (Distance)
Note: Distance (pc) is approx 1000 / parallax.
-- Survey Dataset Stats --
+-------+--------------------+------------------+-------------------+
|summary|            parallax|   phot_g_mean_mag|               pmra|
+-------+--------------------+------------------+-------------------+
|  count|              844868|            844868|             844868|
|   mean|    0.55071253992443|17.404869147894882| -2.373963360005405|
| stddev|  0.7697707648724794|1.4361723326971574| 7.0091972840050865|
|    min|1.868718209532133E-7|         3.0238004|-377.74180694687686|
|    max|   75.56887569382528|         18.999998|  649.0319386167508|
+-------+--------------------+------------------+-------------------+

-- Local Bubble Stats --
+-------+------------------+-----------------+------------------+
|summary|          parallax|  phot_g_mean_mag|              pmra|
+-------+------------------+-----------------+------------------+
|  count|            541958|           540859|            541958|
|   mean|14.283010896409225|16.94108547991132|-3.085397329690918|
| stddev|7.0432588218793475|3.605457797042056|  94.7899342646246|
|    min|10.000005410606198|        1.9425238|-4406.469178827325|
|    max| 768.0665391873573|        21.289928| 6765.995136250774|
+-------+------------------+-----------------+------------------+


>>> NULL VALUE ANALYSIS (Survey Dataset)
+--------+----+------------+-----+
|parallax|pmra|teff_gspphot|bp_rp|
+--------+----+------------+-----+
|       0|   0|      125194|23080|
+--------+----+------------+-----+


>>> NULL VALUE ANALYSIS (Local Dataset)
+--------+----+------------+-----+
|parallax|pmra|teff_gspphot|bp_rp|
+--------+----+------------+-----+
|       0|   0|      423880|68171|
+--------+----+------------+-----+
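
The raw null counts are easier to compare as fractions of each dataset. The short sketch below converts the counts printed above into percentages; the numbers are copied from the output, and the `null_percent` helper is a name introduced here for illustration.

```python
# Row totals and null counts copied from the output above.
SURVEY_ROWS = 844_868
LOCAL_ROWS = 541_958

NULL_COUNTS = {
    "survey": {"teff_gspphot": 125_194, "bp_rp": 23_080},
    "local": {"teff_gspphot": 423_880, "bp_rp": 68_171},
}
TOTALS = {"survey": SURVEY_ROWS, "local": LOCAL_ROWS}


def null_percent(nulls: int, total: int) -> float:
    """Share of rows with a missing value, as a percentage."""
    return 100.0 * nulls / total


for dataset, counts in NULL_COUNTS.items():
    for column, nulls in counts.items():
        pct = null_percent(nulls, TOTALS[dataset])
        print(f"{dataset:>6} {column:<13} {pct:5.1f}% missing")
```

On these counts, teff_gspphot is missing for roughly 15% of the survey rows and 78% of the local rows, while bp_rp is missing far less often in both datasets.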

2 Queries

2.1 Jasmi

2.1.1 The H-R Diagram

The aim is to prepare a high-quality, filtered dataset by calculating the luminosity (Absolute Magnitude, \(M_G\)) for every star in the survey, which provides the data for the H-R Diagram’s Y-axis.

Methodology

This phase uses a Direct Calculation and Quality Filtering approach in Spark SQL. The complexity lies not in aggregation, but in applying the astronomical transformation formula and enforcing a high Signal-to-Noise Ratio (SNR) on the distance data before transferring the results to Python for advanced plotting.

Parameter Justification

To keep the resulting H-R Diagram scientifically precise and prevent smearing of the Main Sequence, the following parameters were chosen:

  • Absolute Magnitude (\(M_G\)) Calculation: The luminosity is calculated using the standard formula:

    \[ M_G = m - 5 \log_{10}(d) + 5 \]

    where \(m\) is the apparent magnitude (phot_g_mean_mag) and \(d\) is the distance in parsecs (derived from \(\frac{1000}{\text{parallax}}\), with the parallax in milliarcseconds). Using \(M_G\) is necessary because the H-R diagram plots intrinsic luminosity, which does not depend on the star’s distance from Earth.

  • Colour Index (bp_rp): This direct photometric measurement serves as a proxy for the star’s effective surface temperature (the X-axis of the H-R Diagram). It is selected for the plot without modification.

  • Quality Cut (Parallax / Parallax Error \(\geq 5.0\)): The signal-to-noise ratio (SNR) of the parallax is calculated dynamically, and only stars with a value of at least 5.0 are kept. This SNR cut is the most important filter, as it ensures that only stars with reliable distance estimates are processed. Stars with poor parallax data would otherwise cause the Main Sequence to appear thick and indistinct, masking key stellar populations.
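
The formula and the quality cut above can be checked on single stars with a few lines of plain Python. This is a minimal sketch that mirrors the definitions in the text; the function names are hypothetical and the example values are made up for illustration.

```python
import math


def absolute_magnitude(apparent_mag: float, parallax_mas: float) -> float:
    """M_G = m - 5 * log10(d) + 5, with d = 1000 / parallax (parallax in mas)."""
    if parallax_mas <= 0:
        raise ValueError("parallax must be positive")
    distance_pc = 1000.0 / parallax_mas
    return apparent_mag - 5.0 * math.log10(distance_pc) + 5.0


def passes_quality_cut(parallax_mas: float, parallax_error_mas: float,
                       min_snr: float = 5.0) -> bool:
    """True when the parallax signal-to-noise ratio is at least min_snr."""
    if parallax_mas <= 0 or parallax_error_mas <= 0:
        return False
    return parallax_mas / parallax_error_mas >= min_snr


# A star at exactly 10 pc (parallax = 100 mas): M_G equals the apparent magnitude.
print(absolute_magnitude(10.0, 100.0))
# Illustrative parallax / error pairs on either side of the cut.
print(passes_quality_cut(2.0, 0.25), passes_quality_cut(1.0, 0.5))
```

The first call reflects the definition of absolute magnitude as the apparent magnitude a star would have at 10 pc.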

The Query Logic

The analysis was performed using a single, efficient query that focused purely on mathematical transformation and data exclusion:

  • Derivation: The Absolute Magnitude (\(\text{abs\_mag}\)) was calculated directly in the SELECT clause by applying the \(\log_{10}\) function to the distance (1000 / parallax).

  • Sanitisation: The WHERE clause removes all low-quality and non-physical measurements:

    • Exclusion of stars with an invalid distance (only rows with parallax > 0 are kept).
    • Exclusion of stars missing essential photometry (only rows where bp_rp IS NOT NULL and phot_g_mean_mag IS NOT NULL are kept).
    • Exclusion of all data points failing the SNR \(\geq 5.0\) standard.
  • Data Transfer: The final output was checked for total row count and then converted from a Spark DataFrame (raw_df) to a standard Pandas DataFrame (pdf). This transfer is necessary to use the plotting capabilities of Matplotlib and NumPy.

Code
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, log10, lit, count, when, sqrt, pow, percentile_approx
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import os

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

from matplotlib import colors
spark = SparkSession.builder \
    .appName("Gaia_HR_Analysis") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

# Loading Data
df_survey = spark.read.parquet("../data/gaia_survey.parquet")
# Creating a temp view
df_survey.createOrReplaceTempView("gaia_survey")

#Query
query_raw = """
SELECT 
    source_id,
    bp_rp,
    -- Calculate Absolute Magnitude (M) directly in Spark
    -- Formula: Apparent Mag - 5 * log10(Distance) + 5
    (phot_g_mean_mag - 5 * LOG10(1000 / parallax) + 5) AS abs_mag
FROM gaia_survey
WHERE 
    parallax > 0 
    AND (parallax / parallax_error) >= 5.0  -- For better precision 
    AND bp_rp IS NOT NULL 
    AND phot_g_mean_mag IS NOT NULL
"""

print(">>> RUNNING QUERY...")
raw_df = spark.sql(query_raw)

# Check how many stars to plot
print(f"Total stars to plot: {raw_df.count():,}")

# Convert to Pandas
pdf = raw_df.toPandas()
print("Data loaded.")

# checking the top 5 rows
print("\n------------------------------\nPrinting first five rows\n------------------------------")
pdf.head()
25/12/18 09:37:37 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
>>> RUNNING QUERY...
Total stars to plot: 296,983
Data loaded.

------------------------------
Printing first five rows
------------------------------
source_id bp_rp abs_mag
0 242682832483968 2.051584 9.188913
1 5474605834008192 0.437714 12.840361
2 6591404705437056 1.121918 6.156537
3 6640539131235840 2.003047 6.955578
4 7473968945265152 1.275984 5.436462

Visualisation Logic

The final visualisation step, performed in Python on the data prepared by this query, uses the following techniques:

  • Advanced Density Scatter: Instead of using fixed SQL bins, the data is plotted using a density-mapped scatter plot.
  • NumPy Density Trick: NumPy’s histogram2d function is used to assign a density score to every individual point.
  • Visual Enhancement: The data is sorted by this density score and plotted with a logarithmic colour scale (cmap='magma'). This technique ensures that the densest region (the Main Sequence core) is plotted last and remains sharp, while also making faint, low-density features (like the White Dwarf sequence) clearly visible.
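
The idea behind the density trick can be illustrated without NumPy. The sketch below is a plain-Python stand-in, not the plotting code used in this report: it bins a handful of made-up points on a fixed grid, looks up how many points share each point's bin, and sorts so that the densest points come last.

```python
from collections import Counter


def density_per_point(xs, ys, nbins, x_range, y_range):
    """Return, for each point, the number of points that share its 2-D bin."""
    x_min, x_max = x_range
    y_min, y_max = y_range

    def bin_index(value, lo, hi):
        # Scale into [0, nbins) and clamp so edge values land in a valid bin.
        idx = int((value - lo) / (hi - lo) * nbins)
        return min(max(idx, 0), nbins - 1)

    bins = [(bin_index(x, x_min, x_max), bin_index(y, y_min, y_max))
            for x, y in zip(xs, ys)]
    counts = Counter(bins)
    return [counts[b] for b in bins]


# Toy data: three points clustered together and one isolated point.
xs = [0.10, 0.12, 0.11, 0.90]
ys = [0.10, 0.11, 0.13, 0.90]

density = density_per_point(xs, ys, nbins=2, x_range=(0, 1), y_range=(0, 1))

# Draw order: sparse points first, dense points last (so they end up on top).
draw_order = sorted(range(len(xs)), key=lambda i: density[i])
print(density, draw_order)
```

The same per-point lookup is what allows a colour scale to represent local density on a scatter plot rather than on a pre-binned image.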
Code
nbins = 300
k = colors.LogNorm() 

x = pdf['bp_rp'].values
y = pdf['abs_mag'].values

plt.style.use('dark_background') # Dark background makes the colors pop

# Create the grid
H, xedges, yedges = np.histogram2d(x, y, bins=nbins)

# Map every star to its bin
x_inds = np.clip(np.digitize(x, xedges) - 1, 0, nbins - 1)
y_inds = np.clip(np.digitize(y, yedges) - 1, 0, nbins - 1)

# Assign density value to each star
z = H[x_inds, y_inds]

# 3. Sort for Sharpness
# Sort the data so the densest (brightest) regions are plotted LAST (on top)
idx = z.argsort()
x, y, z = x[idx], y[idx], z[idx]

# 4. Plot: Use 'magma' colormap for clearer density progression
plt.figure(figsize=(8, 10))

# 'magma' or 'plasma' provide excellent contrast for astronomical density plots
plt.scatter(x, y, c=z, s=0.5, cmap='magma', norm=k, alpha=1.0, edgecolors='none')

# 5. Astronomy Polish
plt.title("Gaia DR3 Hertzsprung-Russell Diagram", fontsize=16)
plt.xlabel("Color Index (BP-RP)", fontsize=12)
plt.ylabel("Absolute Magnitude (M)", fontsize=12)

# Invert Y-Axis (Bright stars go at the top)
plt.ylim(17, -5)
plt.xlim(-1, 5)

# Add Colorbar
cbar = plt.colorbar()
cbar.set_label('Star Density', rotation=270, labelpad=20)

# Annotations marking the main stellar populations
plt.text(0.5, 14, 'White Dwarfs', color='white', fontsize=12, fontweight='bold')
plt.text(1.5, 4, 'Main Sequence', color='gold', fontsize=12, rotation=-45, fontweight='bold')
plt.text(2.5, 0, 'Giants', color='orange', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

[Figure: Advanced Density Scatter]

2.1.2 Stellar Population Census & Astrophysical Classification

This query classifies the stellar population into major spectral types (O through M) and distinguishes Main Sequence stars (dwarfs) from evolved stars (giants).

Methodology

This query uses a Colour–Magnitude approach rather than relying on raw temperature estimates. This method provides more reliable classification, particularly for faint sources where temperature estimations can become inaccurate.

Dataset

Gaia DR3 gaia_source, Galactic Survey sample (random_index < 3,000,000; 844,868 rows after the download filters)

Parameter Justification

To ensure scientific accuracy in the classification process, the following parameters were selected:

  • Colour Index (bp_rp)

  • Quality Cut (Parallax / Parallax Error > 5)

  • Absolute Magnitude (\(M_G\))

Code
# 1. Query
query = """
WITH PhysicalProperties AS (
    SELECT
        source_id,
        bp_rp,
        -- Calculate Absolute Magnitude (Mg)
        phot_g_mean_mag - 5 * LOG10(1000 / parallax) + 5 AS abs_mag_g,
        
        -- Calculate the parallax signal-to-noise ratio manually
        (parallax / parallax_error) AS calculated_poe
    FROM gaia_survey
    WHERE parallax > 0 
      AND parallax_error > 0 -- Prevent divide-by-zero errors
),

FilteredStars AS (
    SELECT * FROM PhysicalProperties
    WHERE calculated_poe > 5 -- Quality Cut: Only keep reliable data
),

ClassifiedStars AS (
    SELECT
        -- 1. Spectral Classification (Colour)
        CASE
            WHEN bp_rp IS NULL THEN 'Unknown'
            WHEN bp_rp > 1.6 THEN 'Cool (M)'
            WHEN bp_rp > 0.9 THEN 'Warm (K)'
            WHEN bp_rp > 0.7 THEN 'Yellow (G)'
            WHEN bp_rp > 0.4 THEN 'Yel-Wht (F)'
            WHEN bp_rp > 0.0 THEN 'White (A)'
            ELSE 'Hot (O/B)'
        END AS spectral_type,

        -- 2. Luminosity Classification (Giant vs Dwarf)
        CASE
            WHEN abs_mag_g < 3.5 AND bp_rp > 0.7 THEN 'Giant'
            ELSE 'Main Sequence (Dwarf)'
        END AS luminosity_class
    FROM FilteredStars
)

SELECT
    spectral_type,
    COUNT(CASE WHEN luminosity_class = 'Main Sequence (Dwarf)' THEN 1 END) AS dwarf_count,
    COUNT(CASE WHEN luminosity_class = 'Giant' THEN 1 END) AS giant_count,
    COUNT(*) AS total_count,
    ROUND(100.0 * COUNT(CASE WHEN luminosity_class = 'Giant' THEN 1 END) / COUNT(*), 1) AS giant_percentage
FROM ClassifiedStars
WHERE spectral_type != 'Unknown'
GROUP BY spectral_type
ORDER BY 
    CASE spectral_type
        WHEN 'Hot (O/B)' THEN 1
        WHEN 'White (A)' THEN 2
        WHEN 'Yel-Wht (F)' THEN 3
        WHEN 'Yellow (G)' THEN 4
        WHEN 'Warm (K)' THEN 5
        WHEN 'Cool (M)' THEN 6
        ELSE 7
    END
"""

# 2. Execute
df_results = spark.sql(query)

# 3. Show Results
print(">> Census Results: Dwarfs vs. Giants (calculated_poe fix)")
df_results.show(truncate=False)
>> Census Results: Dwarfs vs. Giants (calculated_poe fix)
+-------------+-----------+-----------+-----------+----------------+
|spectral_type|dwarf_count|giant_count|total_count|giant_percentage|
+-------------+-----------+-----------+-----------+----------------+
|Hot (O/B)    |94         |0          |94         |0.0             |
|White (A)    |745        |0          |745        |0.0             |
|Yel-Wht (F)  |4536       |0          |4536       |0.0             |
|Yellow (G)   |21349      |6598       |27947      |23.6            |
|Warm (K)     |143669     |21220      |164889     |12.9            |
|Cool (M)     |84014      |14758      |98772      |14.9            |
+-------------+-----------+-----------+-----------+----------------+

Visualisation Logic

To clearly present the results of the analysis, two visualisations were produced:

  • Population Mix (Stacked Bar Chart)
    A 100% stacked bar chart was used to illustrate the proportion of giants versus dwarfs within each spectral type. This visualisation highlights the evolutionary composition of the stellar population and allows for direct comparison across spectral classes.
Code
pdf = df_results.toPandas()

# 2. Setup the Data for Plotting
# We need to sort the data from Hot to Cool for the X-axis
sort_order = {
    'Hot (O/B)': 0, 'White (A)': 1, 'Yel-Wht (F)': 2, 
    'Yellow (G)': 3, 'Warm (K)': 4, 'Cool (M)': 5
}
pdf['sort_id'] = pdf['spectral_type'].map(sort_order)
pdf = pdf.sort_values('sort_id')

# Calculate Percentages for the Stacked Bar
# (We re-calculate here to ensure they sum to exactly 100 for the plot)
pdf['dwarf_pct'] = (pdf['dwarf_count'] / pdf['total_count']) * 100
pdf['giant_pct'] = (pdf['giant_count'] / pdf['total_count']) * 100

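# 3. Create the Plot
# Apply the style before creating the figure so that it takes effect
plt.style.use('dark_background')
fig, ax = plt.subplots(figsize=(10, 6))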

# Plot Dwarfs (Bottom Bar)
p1 = ax.bar(pdf['spectral_type'], pdf['dwarf_pct'], label='Main Sequence (Dwarfs)', 
            color='#1f77b4', edgecolor='black', alpha=0.9)

# Plot Giants (Top Bar)
p2 = ax.bar(pdf['spectral_type'], pdf['giant_pct'], bottom=pdf['dwarf_pct'], 
            label='Giants (Evolved)', color='#d62728', edgecolor='black', alpha=0.9)

# 4. Styling
ax.set_title('Stellar Population Mix: Dwarfs vs. Giants (High-Quality Subset)', fontsize=16)
ax.set_ylabel('Percentage of Population (%)', fontsize=12)
ax.set_xlabel('Spectral Type', fontsize=12)
ax.set_ylim(0, 100)
ax.legend(loc='upper left', frameon=True)
ax.grid(axis='y', linestyle='--', alpha=0.4)

# 5. Add Labels
# Label the Giants if they exist
for i, (idx, row) in enumerate(pdf.iterrows()):
    if row['giant_pct'] > 1:
        ax.text(i, row['dwarf_pct'] + row['giant_pct']/2, f"{row['giant_pct']:.1f}%", 
                ha='center', va='center', color='white', fontweight='bold', fontsize=11)
        
    # Label the Dwarfs
    if row['dwarf_pct'] > 5:
        ax.text(i, row['dwarf_pct']/2, f"{row['dwarf_pct']:.1f}%", 
                ha='center', va='center', color='white', fontweight='bold', fontsize=11)

plt.style.use('dark_background') # Dark background makes the colors pop
plt.tight_layout()
plt.show()

[Figure: Population Mix (Stacked Bar Chart)]

Interpretation of Results

  • O, B, A, and F Spectral Types
    No giants were flagged in these spectral classes (0.0%). This follows directly from the classification rule, which only labels a star as a giant when bp_rp > 0.7, so sources bluer than the G boundary are always counted as dwarfs. The result is nevertheless in line with stellar evolution theory: massive, blue stars evolve rapidly and do not remain in the blue region of the spectrum once they leave the Main Sequence.

  • G and K Spectral Types
    A substantial giant population was observed within these classes, accounting for approximately 13–24% of the sample (23.6% for G and 12.9% for K). This traces the Red Giant Branch and distinguishes evolved stars, such as Arcturus, from solar-type dwarfs.

  • M Spectral Type
    Dwarfs account for 85.1% of the M-type population, reflecting the high abundance of red dwarfs in the galaxy. The 14.9% giant share should be read with the survey's magnitude limit in mind (phot_g_mean_mag < 19), which favours intrinsically luminous stars over faint dwarfs at large distances.
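
The classification rule behind these numbers can be restated in plain Python, which makes its edge cases easy to inspect. This sketch mirrors the thresholds in the CASE expressions of the query; the function names are introduced here for illustration, and the example values are made up.

```python
from typing import Optional


def spectral_type(bp_rp: Optional[float]) -> str:
    """Colour-based spectral bin, using the same thresholds as the SQL CASE."""
    if bp_rp is None:
        return "Unknown"
    if bp_rp > 1.6:
        return "Cool (M)"
    if bp_rp > 0.9:
        return "Warm (K)"
    if bp_rp > 0.7:
        return "Yellow (G)"
    if bp_rp > 0.4:
        return "Yel-Wht (F)"
    if bp_rp > 0.0:
        return "White (A)"
    return "Hot (O/B)"


def luminosity_class(abs_mag_g: float, bp_rp: Optional[float]) -> str:
    """Giant only if the star is both luminous and redder than bp_rp = 0.7."""
    if bp_rp is not None and abs_mag_g < 3.5 and bp_rp > 0.7:
        return "Giant"
    return "Main Sequence (Dwarf)"


# A luminous K-coloured star is flagged as a giant ...
print(spectral_type(1.2), luminosity_class(1.0, 1.2))
# ... but an equally luminous A-coloured star never can be under this rule.
print(spectral_type(0.2), luminosity_class(1.0, 0.2))
```

Because the giant test requires bp_rp > 0.7, the O/B, A and F bins can only ever contain dwarfs under this scheme, regardless of absolute magnitude.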

2.1.3 High-Velocity Outlier Detection (Kinematic Analysis)

This query identifies and flags the top 1% of stars in the df_local dataset with the highest Total Proper Motion (apparent angular speed across the sky). These stars are often key kinematic outliers, such as halo stars or nearby high-velocity dwarfs.

  • Columns Needed: pmra (Proper Motion Right Ascension), pmdec (Proper Motion Declination).

  • SQL Complexity: Simplified and Optimised. The executed code avoids the slow, complex nested SQL window function (PERCENT_RANK()) proposed in the original plan, replacing it with a direct, single-action calculation in Spark.

    1. Mathematical Transformation: Total Proper Motion (\(\mu\)) is calculated using the Pythagorean theorem: \(\mu = \sqrt{\mu_{\alpha}^2 + \mu_{\delta}^2}\).
      • Code: sqrt(pow(col("pmra"), 2) + pow(col("pmdec"), 2))
    2. Threshold Calculation: The complexity is offloaded to the optimized Spark function percentile_approx(). This directly computes the 99th percentile proper motion value (pm_threshold) in a single, fast aggregation.
      • Code: df_motion.agg(percentile_approx("total_pm", lit(0.99)))
    3. Flagging (Simple Comparison): The final SQL query compares each star’s total_pm against pm_threshold in a CASE expression, avoiding a subquery and expensive ranking.
      • Code: CASE WHEN total_pm >= {pm_threshold} THEN 1 ELSE 0 END AS is_high_pm
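
The three steps above can be traced on a toy list in plain Python. This is a sketch only: Spark's percentile_approx returns an approximate percentile, whereas the stand-in below uses an exact nearest-rank percentile, and all function names and data values are hypothetical.

```python
import math


def total_proper_motion(pmra: float, pmdec: float) -> float:
    """Step 1: combine both components, mu = sqrt(pmra^2 + pmdec^2)."""
    return math.hypot(pmra, pmdec)


def nearest_rank_percentile(values, fraction: float) -> float:
    """Step 2 (stand-in): exact nearest-rank percentile of a list of values."""
    ordered = sorted(values)
    rank = max(1, math.ceil(fraction * len(ordered)))
    return ordered[rank - 1]


def flag_high_pm(values, threshold: float):
    """Step 3: 1 if the value reaches the threshold, else 0."""
    return [1 if v >= threshold else 0 for v in values]


# Toy proper-motion components in mas/yr (made-up values).
components = [(3.0, 4.0), (6.0, 8.0), (-5.0, 12.0), (30.0, -40.0)]
total_pm = [total_proper_motion(ra, dec) for ra, dec in components]

threshold = nearest_rank_percentile(total_pm, 0.75)
print(total_pm, threshold, flag_high_pm(total_pm, threshold))
```

Computing the threshold once and then comparing each row against a constant is what lets the real query avoid a ranking window over the whole dataset.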
Code
# ====================================================
# 1. Calculate Total Proper Motion & Find the Threshold
# ====================================================

# Calculate total proper motion (total_pm = sqrt(pmra^2 + pmdec^2))
df_motion = df_local.withColumn(
    "total_pm", 
    sqrt(pow(col("pmra"), 2) + pow(col("pmdec"), 2))
)

# Find the 99th percentile (Top 1%) proper motion value
# This value will be our threshold (e.g., 100 mas/yr)
pm_threshold = df_motion.agg(
    percentile_approx("total_pm", lit(0.99)).alias("threshold_value")
).collect()[0]["threshold_value"]

print(f">>> Calculated 99th Percentile Proper Motion Threshold: {pm_threshold:.2f} mas/yr")

# ====================================================
# 2. SQL Query: Prepare Data for Plotting
# ====================================================

# Create a temporary view for the motion-enhanced DataFrame
df_motion.createOrReplaceTempView("gaia_motion")

# The query selects necessary fields and flags the fast-moving stars
plot_query_with_motion = f"""
SELECT 
    source_id,
    bp_rp,
    -- Calculate Absolute Magnitude (Mg)
    (phot_g_mean_mag - 5 * LOG10(1000 / parallax) + 5) AS abs_mag_g,
    total_pm,
    
    -- Flag if the star is in the top 1% of motion
    CASE 
        WHEN total_pm >= {pm_threshold} THEN 1 
        ELSE 0 
    END AS is_high_pm
    
FROM gaia_motion
WHERE parallax > 0 
  AND (parallax / parallax_error) > 5  -- The Quality Cut
  AND phot_g_mean_mag IS NOT NULL 
"""

# Execute the query
df_plot_motion = spark.sql(plot_query_with_motion)

# Convert to Pandas for plotting
pdf_motion = df_plot_motion.toPandas()

print(f"Total stars prepared for plotting: {len(pdf_motion):,}")
print(f"Total high-PM stars flagged: {pdf_motion['is_high_pm'].sum():,}")
>>> Calculated 99th Percentile Proper Motion Threshold: 450.85 mas/yr
Total stars prepared for plotting: 540,859
Total high-PM stars flagged: 5,435

Visualization

A Layered Scatter Plot on the H-R Diagram was used (as seen in the executed code). This is a highly effective scientific visualisation that goes beyond the simple table proposed in the original plan. It plots:

  • Background: The bulk (slow-moving) population (purple).
  • Foreground: The high-velocity outliers (is_high_pm = 1) in a distinct colour (cyan), showing their position relative to the main stellar sequences.

Code
import matplotlib.pyplot as plt

# Filter the data into two subsets
pdf_slow = pdf_motion[pdf_motion['is_high_pm'] == 0]
pdf_fast = pdf_motion[pdf_motion['is_high_pm'] == 1]

# 1. Create the Plot Canvas (style first, so the new figure picks it up)
plt.style.use('dark_background')
plt.figure(figsize=(8, 10))

# 2. Plot the Bulk Population (Slow/Main Sequence)
# Use a low alpha (high transparency) so overlapping points reveal the density
plt.scatter(
    pdf_slow['bp_rp'], 
    pdf_slow['abs_mag_g'], 
    c='purple',         # Base colour
    s=0.5,              # Small size
    alpha=0.2,          # Very transparent to show density variation
    edgecolors='none', 
    label='Bulk Population (Low PM)'
)

# 3. OVERLAY the High-Velocity Stars
# Use a distinct, bright colour and larger size
plt.scatter(
    pdf_fast['bp_rp'], 
    pdf_fast['abs_mag_g'], 
    c='cyan',           # Stand-out colour
    s=5,                # Much larger size
    alpha=1.0,          # Fully opaque
    edgecolors='none',
    label=f'High-Velocity Outliers (Top 1% > {pm_threshold:.0f} mas/yr)'
)

# 4. Polish and Axes
plt.gca().invert_yaxis()  # Brighter stars (lower mag) go at the top
plt.title('HR Diagram Highlighting Kinematic Outliers (Top 1% Proper Motion)', fontsize=16)
plt.xlabel('Colour Index (BP-RP)')
plt.ylabel('Absolute Magnitude (M)')
plt.xlim(-1, 5)
plt.ylim(17, -5)
plt.legend(loc='upper right')
plt.grid(alpha=0.1)

plt.style.use('dark_background') # Dark background makes the colors pop
plt.show()
spark.stop()

[Figure: Layered Scatter Plot on the H-R Diagram]

Interpretation:

The separation of stars by Proper Motion (PM) provides a “motion-based” window into Galactic history that photometry (brightness/colour) alone cannot reveal.

Stellar Populations: Disk vs. Halo

  • The Bulk Population (Purple): These stars represent the Galactic Disk. They move in orderly, circular orbits around the Galactic Centre, similar to the Sun. Their relative motion to us is low, hence their “slow” proper motion.

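  • High-Velocity Outliers (Cyan): These include candidate Galactic Halo stars, which move on highly elliptical, randomly oriented orbits that “dive” through the disk at hundreds of kilometres per second. Because proper motion is an angular rate, however, very nearby disk stars can also exceed the threshold. Converting proper motion to tangential velocity using the parallax is needed to separate the two groups.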
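
Proper motion only becomes a physical speed once distance is taken into account. The sketch below applies the standard relation \(v_t = 4.74047 \, \mu \, d\) (with \(\mu\) in arcseconds per year and \(d\) in parsecs) to the 450.85 mas/yr threshold found above; the function name is hypothetical, and this calculation is an illustration rather than part of the executed analysis.

```python
# 1 au/yr expressed in km/s; converts (arcsec/yr * pc) into km/s.
KM_S_PER_AU_YR = 4.74047


def tangential_velocity_km_s(total_pm_mas_yr: float, parallax_mas: float) -> float:
    """Tangential velocity from proper motion (mas/yr) and parallax (mas).

    v_t = 4.74047 * mu[arcsec/yr] * d[pc]; with mu in mas/yr and
    d = 1000 / parallax, the factors of 1000 cancel.
    """
    if parallax_mas <= 0:
        raise ValueError("parallax must be positive")
    return KM_S_PER_AU_YR * total_pm_mas_yr / parallax_mas


threshold_pm = 450.85  # mas/yr, the 99th-percentile threshold reported above

# Same proper motion at the edge of the sample (100 pc) and very nearby (10 pc).
for parallax in (10.0, 100.0):
    v_t = tangential_velocity_km_s(threshold_pm, parallax)
    print(f"parallax {parallax:6.1f} mas -> v_t = {v_t:6.1f} km/s")
```

A star at the proper-motion threshold therefore moves about ten times faster across the sky plane at 100 pc than one with the same proper motion at 10 pc, which is why a proper-motion cut alone mixes fast halo candidates with slow but very close disk stars.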

2.2 Yogi

2.2.1 Galactic Plane vs Halo

Code
from pyspark.sql import SparkSession
#from pyspark.sql.functions import * 
import pyspark.sql.functions as F
from pyspark.sql.window import Window
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Initialize Spark
spark = SparkSession.builder \
    .appName("Member2_Galactic_Structure") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

# 1. Load the Data
parquet_path = "../data/gaia_survey.parquet"

df = spark.read.parquet(parquet_path)


# describe the gaia_survey data
df.describe().show()
25/12/18 09:37:46 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
+-------+--------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+------------------+
|summary|           source_id|                  ra|                dec|            parallax|     parallax_error|               pmra|              pmdec|   phot_g_mean_mag|             bp_rp|      teff_gspphot|
+-------+--------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+------------------+
|  count|              844868|              844868|             844868|              844868|             844868|             844868|             844868|            844868|            821788|            719674|
|   mean|4.208655544660601...|   220.3738280138546|-14.661386665029692|    0.55071253992443|0.12661974095717216| -2.373963360005405|-3.0263002891849347|17.404869147894882|1.5300547396294117|  4878.39501996751|
| stddev|1.770170248935241...|   84.67372412021707|  38.84441601829229|  0.7697707648724794|0.08883984995128799| 7.0091972840050865|  6.990431755441395|1.4361723326971574|0.6075873126388648|1028.3654349007377|
|    min|      42159399217024|2.620729015336566...| -89.92247730288476|1.868718209532133E-7|        0.007789557|-377.74180694687686| -565.3362997837297|         3.0238004|       -0.54805183|          2739.997|
|    max| 6917515734718201472|   359.9982847741013|  89.77689025849239|   75.56887569382528|           1.519277|  649.0319386167508|  340.6398385516719|         18.999998|          7.822592|          37489.88|
+-------+--------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+------------------+
Code
# 2. Register as a Temp View 
# This allows us to use SQL commands on the dataframe 'df'
df.createOrReplaceTempView("gaia_source")

print(">>> Executing Query 2.1: Galactic Plane vs Halo...")

# 3. The Spark SQL Query
# We use a CTE (Common Table Expression) named 'CalculatedData' to do the math first,
# and then the main query does the aggregation. 
query = """
    WITH CalculatedData AS (
        SELECT 
            source_id,
            -- Calculate Total Motion (Hypotenuse of pmra and pmdec)
            SQRT(POW(pmra, 2) + POW(pmdec, 2)) AS total_motion,
            
            -- Define Region using the 'DEC Proxy' method
            CASE 
                WHEN ABS(dec) < 15 THEN 'Galactic Plane' 
                ELSE 'Galactic Halo' 
            END AS region
        FROM gaia_source
    )
    
    -- Final Aggregation
    SELECT 
        region,
        ROUND(AVG(total_motion), 2) AS avg_speed,
        ROUND(STDDEV(total_motion), 2) AS stddev_speed,
        COUNT(*) AS star_count
    FROM CalculatedData
    GROUP BY region
    ORDER BY avg_speed DESC
"""

# 4. Run the query
sql_results = spark.sql(query)

# 5. Show results
sql_results.show()
>>> Executing Query 2.1: Galactic Plane vs Halo...
+--------------+---------+------------+----------+
|        region|avg_speed|stddev_speed|star_count|
+--------------+---------+------------+----------+
| Galactic Halo|     7.03|        7.73|    691873|
|Galactic Plane|     6.96|        8.96|    152995|
+--------------+---------+------------+----------+

Analysis

The objective of this query was to compare the stellar kinematics of the dense Galactic Disk and the Galactic Halo using proper motion.

  • Metric: We calculated the “Total Proper Motion” (\(\mu\)), in milliarcseconds per year (mas/yr), for each star by combining its two components: \(\mu = \sqrt{\texttt{pmra}^2 + \texttt{pmdec}^2}\). This is an angular rate on the sky, not a physical speed.
  • Segmentation: Due to dataset constraints, we used Declination (dec) as a proxy for Galactic Latitude. We defined the “Galactic Plane” as the equatorial band (\(|dec| < 15^{\circ}\)) and the “Halo” as the high-latitude regions (\(|dec| \ge 15^{\circ}\)).

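Finding A: The “Missed Galaxy” Problem (Star Count)

  • The simple reason is that the Milky Way is tilted: the Galactic plane is inclined by roughly 60° to the celestial equator.
  • Our query only sees the flat strip across the middle of the sky (dec between -15° and +15°). Because of that, our “flat strip” misses most of the Galactic disk: the “Plane” group holds only 152,995 stars, while the “Halo” group holds 691,873 and inevitably includes many disk stars.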
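
As a hedged illustration of the tilt, the sketch below converts equatorial coordinates to Galactic latitude \(b\) using the standard J2000 position of the North Galactic Pole (RA 192.85948°, Dec 27.12825°). The function name galactic_latitude is hypothetical and not part of the original notebook; it only shows how a true latitude cut could replace the dec proxy.

```python
import math

# J2000 coordinates of the North Galactic Pole (degrees).
NGP_RA_DEG = 192.85948
NGP_DEC_DEG = 27.12825


def galactic_latitude(ra_deg: float, dec_deg: float) -> float:
    """Return Galactic latitude b in degrees for J2000 (ra, dec) in degrees."""
    ra, dec = math.radians(ra_deg), math.radians(dec_deg)
    ra_p, dec_p = math.radians(NGP_RA_DEG), math.radians(NGP_DEC_DEG)
    sin_b = (math.sin(dec) * math.sin(dec_p)
             + math.cos(dec) * math.cos(dec_p) * math.cos(ra - ra_p))
    # Clamp to [-1, 1] to guard against floating-point overshoot.
    return math.degrees(math.asin(max(-1.0, min(1.0, sin_b))))


# Sagittarius A* lies in the Galactic plane even though its dec is about -29 deg,
# so the dec proxy would label it 'Galactic Halo'.
b_centre = galactic_latitude(266.417, -29.008)
# A point on the celestial equator can be far from the Galactic plane.
b_equator = galactic_latitude(190.0, 0.0)
print(f"b(Sgr A*)  = {b_centre:+.2f} deg")
print(f"b(equator) = {b_equator:+.2f} deg")
```

With a function like this, the CASE expression could test ABS(b) instead of ABS(dec), at the cost of one extra trigonometric expression per row.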

Finding B: The Velocity Variation

  • We expected the “Halo” to have a higher velocity than the “Plane”. Instead, the average proper motions are almost identical (7.03 vs 6.96 mas/yr), and the “Plane” shows the larger spread (standard deviation 8.96 vs 7.73).

The simple reason is that distance changes how speed looks: proper motion is an angular rate, so it shrinks with distance.

  • The Disk: Contains many stars that are close to Earth. Because they are close, their motions look dramatic and varied from our vantage point.

  • The Halo: Stars are incredibly far away. Even if they are moving fast, their distance makes them all appear to move slowly and steadily, leading to a “lower” measurement.
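
The distance effect can be made concrete with the standard relation between proper motion and tangential velocity, \(v_t = 4.74047\,\mu / \varpi\) km/s, with \(\mu\) in mas/yr and the parallax \(\varpi\) in mas. A minimal sketch, with hypothetical function names and made-up example values rather than rows from the dataset:

```python
import math

K_KM_S = 4.74047  # km/s corresponding to 1 AU/yr, the usual conversion factor


def total_proper_motion(pmra: float, pmdec: float) -> float:
    """Total proper motion in mas/yr from its two components."""
    return math.hypot(pmra, pmdec)


def tangential_velocity(mu_mas_yr: float, parallax_mas: float) -> float:
    """Tangential velocity in km/s; requires a positive parallax in mas."""
    if parallax_mas <= 0:
        raise ValueError("parallax must be positive")
    return K_KM_S * mu_mas_yr / parallax_mas


# Two illustrative (made-up) stars with the same total proper motion:
mu = total_proper_motion(3.0, 4.0)       # components of 3 and 4 mas/yr
v_near = tangential_velocity(mu, 10.0)   # parallax 10 mas, i.e. 100 pc away
v_far = tangential_velocity(mu, 0.2)     # parallax 0.2 mas, i.e. 5000 pc away
print(f"near: {v_near:.1f} km/s, far: {v_far:.1f} km/s")
```

The same angular motion corresponds to a 50 times larger physical velocity for the distant star, which is why comparing raw proper motions between populations at different distances can mislead.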

Visualization

Code
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# 1. Take a random 1% sample (Crucial for performance)
pdf_subset = df.select("ra", "dec").sample(fraction=0.01, seed=42).toPandas()

# 2. Re-create the Region logic (vectorised with numpy)
pdf_subset['Region'] = np.where(pdf_subset['dec'].abs() < 15, 'Galactic Plane', 'Galactic Halo')

# 3. SET THE THEME: Dark Background for a "Space" look
plt.style.use('dark_background')
plt.figure(figsize=(12, 7))

# 4. Plot the "Halo" (Background Stars)
# We plot these first in cool blue so they look "distant"
sns.scatterplot(
    data=pdf_subset[pdf_subset['Region'] == 'Galactic Halo'], 
    x='ra', 
    y='dec', 
    color='cornflowerblue', 
    s=5,           # Small dots
    alpha=0.3,     # Faint transparency
    edgecolor=None,
    label='Galactic Halo (Sparse)'
)

# 5. Plot the "Plane" (The Disk)
# We plot these on top in bright Gold to represent the dense star field
sns.scatterplot(
    data=pdf_subset[pdf_subset['Region'] == 'Galactic Plane'], 
    x='ra', 
    y='dec', 
    color='#FFD700', # Gold color
    s=10,            # Slightly larger dots
    alpha=0.4,       # Brighter
    edgecolor=None,
    label='Galactic Plane (Dense)'
)

# Draw the cut-off lines
plt.axhline(15, color='white', linestyle='--', linewidth=1, alpha=0.5)
plt.axhline(-15, color='white', linestyle='--', linewidth=1, alpha=0.5)

# Add text labels on the graph
plt.text(180, 0, "Milky Way Disk\n(High Density)", color='orangered', 
         ha='center', va='center', fontsize=12, )

plt.text(180, 60, "Galactic Halo\n(Low Density)", color='cornflowerblue', 
         ha='center', va='center', fontsize=10)

# 7. Final Polish
plt.title("Spatial Structure: The 'Flat' Disk vs. The 'Round' Halo", fontsize=14, color='white')
plt.xlabel("Right Ascension (Longitude)", fontsize=12)
plt.ylabel("Declination (Latitude)", fontsize=12)
plt.legend(loc='upper right', facecolor='black', edgecolor='white')
plt.grid(False) # Turn off grid to look more like space

# Astronomers view the sky looking "up", so we invert the X-axis
plt.gca().invert_xaxis()

plt.show()

Galactic Plane vs Halo

2.2.2 Star Density Sky Map

Code
print(">>> Executing Query 2.2: Star Density Sky Map")

query_density = """
    WITH DensityBins AS (
        SELECT 
            -- 1. Spatial Binning (The 'Grid')
            -- We divide by 2, floor it to remove decimals, then multiply by 2
            -- This snaps every star to the lower edge of its 2-degree grid cell (..., -2, 0, 2, 4...)
            FLOOR(ra / 2) * 2 AS ra_bin,
            FLOOR(dec / 2) * 2 AS dec_bin,
            
            -- 2. Aggregation (Counting stars in that grid square)
            COUNT(*) AS star_count
        FROM gaia_source
        GROUP BY 1, 2  -- Group by the first two columns (ra_bin, dec_bin)
    ),
    
    RankedRegions AS (
        SELECT 
            *,
            -- Rank the bins from most populated (1) to least populated
            RANK() OVER (ORDER BY star_count DESC) as density_rank
        FROM DensityBins
    )
    
    -- 4. Final Result: Top 5 Densest Regions
    SELECT * FROM RankedRegions 
    WHERE density_rank <= 5
"""

# 3. Run and Show
sql_density_results = spark.sql(query_density)
sql_density_results.show()
>>> Executing Query 2.2: Star Density Sky Map
25/12/18 09:37:48 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+------+-------+----------+------------+
|ra_bin|dec_bin|star_count|density_rank|
+------+-------+----------+------------+
|   270|    -30|      2135|           1|
|   272|    -30|      2096|           2|
|   268|    -30|      2036|           3|
|   272|    -28|      2011|           4|
|   274|    -28|      1762|           5|
+------+-------+----------+------------+

Analysis

The goal of this query was to create a “Stellar Density Map” to identify the most populated regions of the sky.

  • Spatial Binning Strategy: We used a quantization approach to divide the continuous sky into a discrete grid. By applying FLOOR(coordinate / 2) * 2 to both Right Ascension (ra) and Declination (dec), we grouped stars into \(2^{\circ} \times 2^{\circ}\) spatial bins.
  • Analytical Complexity: To rank the bins by density, we used a Spark SQL window function (RANK() OVER (ORDER BY star_count DESC)). Because the window has no PARTITION BY clause, Spark moves all rows to a single partition (hence the WindowExec warning); this is acceptable here because the window runs over the aggregated grid of at most 16,200 bins, not over the raw stars.
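
The binning rule can be checked outside Spark with a few lines of plain Python. This is only a sketch with a hypothetical helper name (grid_bin) and made-up coordinates; it mirrors FLOOR(coordinate / 2) * 2 and shows that negative declinations snap downward to the lower edge of their cell.

```python
import math
from collections import Counter


def grid_bin(value_deg: float, size_deg: float = 2.0) -> int:
    """Lower edge of the grid cell containing value_deg (FLOOR(v / size) * size)."""
    return int(math.floor(value_deg / size_deg) * size_deg)


# Made-up (ra, dec) positions in degrees, for illustration only.
stars = [(270.4, -29.3), (271.9, -28.1), (270.0, -30.0), (10.2, 45.7), (11.9, 44.1)]

# Count stars per (ra_bin, dec_bin) cell, like the GROUP BY in the query.
counts = Counter((grid_bin(ra), grid_bin(dec)) for ra, dec in stars)

# List the cells from most to least populated, like ORDER BY star_count DESC.
for (ra_bin, dec_bin), n in counts.most_common():
    print(ra_bin, dec_bin, n)
```

Note that cells of fixed width in RA and Dec do not cover equal sky area (they shrink toward the poles by roughly cos(dec)), so raw counts slightly favour low declinations; this sketch does not correct for that.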

Critical Analysis:

  1. Algorithm Validation: The results validate the binning algorithm, which correctly picks out the region around the Galactic centre.
  2. Astronomical Validation: The coordinates returned (\(RA \approx 270^{\circ}\), \(Dec \approx -30^{\circ}\)) lie in the constellation Sagittarius.
    • The official coordinates of Sagittarius A* (the supermassive black hole at the centre of the Milky Way) are \(RA \approx 266^{\circ}\), \(Dec \approx -29^{\circ}\), only a few degrees from the densest bins.
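
To quantify how close the densest bin is to the Galactic centre, the angular separation can be computed with the spherical law of cosines. The sketch below assumes the bin centre is the lower edge plus 1° (so the top-ranked bin at (270, -30) is centred on (271, -29)) and uses approximate J2000 coordinates for Sagittarius A*; the helper name is hypothetical.

```python
import math


def angular_separation_deg(ra1, dec1, ra2, dec2):
    """Great-circle separation in degrees between two (ra, dec) points in degrees."""
    ra1, dec1, ra2, dec2 = map(math.radians, (ra1, dec1, ra2, dec2))
    cos_sep = (math.sin(dec1) * math.sin(dec2)
               + math.cos(dec1) * math.cos(dec2) * math.cos(ra1 - ra2))
    # Clamp to [-1, 1] so rounding errors cannot push acos out of its domain.
    return math.degrees(math.acos(max(-1.0, min(1.0, cos_sep))))


# Centre of the top-ranked 2-degree bin (lower edges 270, -30) vs. Sagittarius A*.
sep = angular_separation_deg(271.0, -29.0, 266.417, -29.008)
print(f"Separation: {sep:.1f} deg")
```

The separation comes out at a few degrees, roughly two bin widths, which supports reading the densest bins as the bulge region rather than an artefact of the grid.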

Interpretation:

The high star counts in these bins confirm that we are looking along the Galactic plane toward the central bulge.

Visualization

Code
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from pyspark.sql.functions import col

# 1. Reuse the "DensityBins" logic from Query 2.2, but save it as a real Dataframe
density_bins = spark.sql("""
    SELECT 
        FLOOR(ra / 2) * 2 AS ra_bin, 
        FLOOR(dec / 2) * 2 AS dec_bin,
        COUNT(*) as local_density
    FROM gaia_source 
    GROUP BY 1, 2
""")

# We give every star a new property: "How many neighbors do I have?"
# We calculate the bins on the fly for the join
df_with_density = df.withColumn("ra_bin", F.floor(col("ra") / 2) * 2) \
                    .withColumn("dec_bin", F.floor(col("dec") / 2) * 2) \
                    .join(density_bins, ["ra_bin", "dec_bin"])

# This is enough to look like "every star" to the human eye without crashing.
pdf_visual = df_with_density.sample(fraction=0.05, seed=42).select("ra", "dec", "local_density").toPandas()

# 4. The "Glowing" Scatter Plot
plt.style.use('default')
plt.figure(figsize=(14, 8))

# We sort by density so the bright stars are plotted ON TOP of the dark ones
pdf_visual = pdf_visual.sort_values("local_density")

scatter = plt.scatter(
    pdf_visual['ra'], 
    pdf_visual['dec'], 
    c=pdf_visual['local_density'], # Color by density
    cmap='magma',                # Magma/Inferno = glowing fire effect
    s=2,                           # Tiny dots 
    alpha=0.8,                     # High opacity to make them pop
    edgecolors='none'              # No borders
)

# 5. Add a Color Bar (Legend)
cbar = plt.colorbar(scatter)
cbar.set_label('Stellar Density (Stars per bin)', rotation=270, labelpad=20, color='Black')
cbar.ax.yaxis.set_tick_params(color='Black')
plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='Black')

# 6. Styling
plt.title("The Milky Way: Stellar Density Visualization", fontsize=18, color='Black' )
plt.xlabel("Right Ascension", fontsize=12, color='Black')
plt.ylabel("Declination", fontsize=12, color='Black')
plt.gca().invert_xaxis() # Astronomical standard
plt.grid(False)

plt.show()

Star Density Sky Map

2.2.3 Parallax Error vs Brightness

Code
print(">>> Executing Query 2.3: Parallax Error vs. Brightness...")

query_quality = """
    WITH QualityMetrics AS (
        SELECT 
            -- 1. Create Brightness Bins (0-21)
            -- FLOOR groups '10.2' and '10.9' into bin '10'
            FLOOR(phot_g_mean_mag) AS mag_bin,
            
            -- Pass through the raw data we need
            parallax_error,
            parallax
        FROM gaia_source
        WHERE parallax > 0  -- Filter out bad data to avoid division by zero
    )
    
    -- 2. Aggregation to find the Error Trend
    SELECT 
        mag_bin AS magnitude,
        
        -- A. Count how many stars are in this brightness range
        COUNT(*) AS star_count,
        
        -- B. Average Absolute Error (Raw uncertainty)
        ROUND(AVG(parallax_error), 4) AS avg_raw_error,
        
        -- C. Average Relative Error (The "Percentage" uncertainty)
        -- Formula: Error / Total Signal
        ROUND(AVG(parallax_error / parallax), 4) AS avg_relative_error
        
    FROM QualityMetrics
    WHERE mag_bin > 0 AND mag_bin < 22 -- Focus on the valid main range
    GROUP BY mag_bin
    ORDER BY mag_bin ASC -- Sort from Bright -> Dim
"""

# 3. Run and Show
quality_results = spark.sql(query_quality)
quality_results.show(25) # Show 25 rows to see the full range
>>> Executing Query 2.3: Parallax Error vs. Brightness...
+---------+----------+-------------+------------------+
|magnitude|star_count|avg_raw_error|avg_relative_error|
+---------+----------+-------------+------------------+
|        3|         1|       0.1317|            0.0438|
|        4|         1|       0.0749|            0.0146|
|        5|         7|       0.0586|             0.023|
|        6|        27|       0.0426|            0.0125|
|        7|        61|       0.0357|             0.013|
|        8|       197|       0.0401|            0.0207|
|        9|       488|       0.0263|             0.017|
|       10|      1304|       0.0264|            0.0229|
|       11|      3007|       0.0265|            0.0299|
|       12|      7054|       0.0246|            0.0487|
|       13|     15489|       0.0239|            0.0655|
|       14|     32688|       0.0297|            0.2688|
|       15|     66266|       0.0411|            1.2044|
|       16|    125846|       0.0633|            0.4907|
|       17|    223654|        0.105|            1.4188|
|       18|    368778|       0.1929|            4.8654|
+---------+----------+-------------+------------------+

Analysis

The objective of this query was to evaluate the dependability of the parallax measurements at various stellar magnitudes. Before moving on to the machine learning part, it is essential to identify the “signal-to-noise” ratios.

  • Binning Strategy: We grouped stars by their Apparent Magnitude (phot_g_mean_mag) into integer bins (e.g., Magnitude 10, 11, 12…).
  • For each bin we calculated:
    • Average Absolute Error: AVG(parallax_error)
    • Average Relative Error: AVG(parallax_error / parallax)
  • Statistics:
    • Bright stars (Mag < 13): High photon counts result in high-precision centroids, leading to low parallax error. In the 9–13 bins the average raw error stays near 0.02–0.03 mas.
    • Dim stars (Mag > 13): Lower signal-to-noise ratios result in roughly exponentially increasing errors, from 0.03 mas at magnitude 14 to 0.19 mas at magnitude 18.
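
The avg_relative_error column is not monotonic (1.2044 at magnitude 15 versus 0.4907 at 16). One plausible explanation, not verified against the data here, is that a mean of ratios is dominated by the few sources whose parallax is close to zero. The toy example below uses made-up numbers to show the effect and why a median, or a cut on parallax / parallax_error, gives a more robust summary.

```python
from statistics import mean, median

# Made-up (parallax, parallax_error) pairs in mas, for illustration only.
# The last source has a parallax very close to zero.
sources = [(2.0, 0.04), (1.5, 0.04), (1.0, 0.04), (0.8, 0.04), (0.001, 0.04)]

relative_errors = [err / plx for plx, err in sources]

mean_rel = mean(relative_errors)      # pulled up by the near-zero parallax
median_rel = median(relative_errors)  # barely affected by it

# Alternative: keep only sources with parallax signal-to-noise above 5.
good = [err / plx for plx, err in sources if plx / err > 5]
mean_rel_good = mean(good)

print(f"mean = {mean_rel:.3f}, median = {median_rel:.3f}, "
      f"mean (SNR > 5) = {mean_rel_good:.3f}")
```

In Spark SQL the same idea could be expressed with an approximate percentile aggregate or an extra WHERE clause; which option suits the later modelling step is a judgment call for the authors.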

Visualization

Code
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# 1. Convert Spark Result to Pandas
pdf_quality = quality_results.toPandas()

# 2. Setup the Plot (Dual Axis)
# We want to show Star Count (Bars) AND Error Rate (Line) on the same chart
fig, ax1 = plt.subplots(figsize=(12, 6))

# 3. Plot A: Star Count (The Histogram) on Left Axis
# This shows where most of our data lives
sns.barplot(
    data=pdf_quality, 
    x='magnitude', 
    y='star_count', 
    color='cornflowerblue', 
    alpha=0.3, 
    ax=ax1,
    label='Star Count'
)
ax1.set_ylabel('Number of Stars (Log Scale)', color='blue')
ax1.tick_params(axis='y', labelcolor='blue')
ax1.set_yscale('log') # Log scale because star counts vary wildly (1 to 300,000)

# 4. Plot B: The Parallax Error (AVG Error) on Right Axis
ax2 = ax1.twinx()

sns.lineplot(
    data=pdf_quality, 
    x=ax1.get_xticks(), # Align line with bars
    y='avg_raw_error',  # AVG(parallax_error)
    color='red', 
    marker='o', 
    linewidth=2,
    ax=ax2,
    label='Avg Parallax Error (mas)'
)
ax2.set_ylabel('Avg Parallax Error (milliarcseconds)', color='red')
ax2.tick_params(axis='y', labelcolor='red')

# 5. Titles and Layout
plt.title("Data Quality Audit (Error Rate vs. Brightness)", fontsize=14)
ax1.set_xlabel("Apparent Magnitude (Lower = Brighter)", fontsize=12)
plt.grid(True, linestyle=':', alpha=0.5)

plt.tight_layout()
plt.show()

Parallax Error vs Brightness

Code
spark.stop()

2.3 Jayrup (Exotic Star Hunting)

Finding rare and interesting stellar objects

2.3.1 White Dwarf Candidates

The goal is to find White Dwarfs, which are the hot, dense cores of dead stars. They are very hot but very dim. Because they are remnants that have not yet completely cooled down, they start out bright-blue (top-left of the white dwarf sequence) and, as they age, turn dim-red (bottom-right). Scientists use this to estimate the ages of stellar populations (and hence a lower limit on the age of the universe), as their cooling rate is fairly stable.

Importing Libraries

Code
import numpy as np
import seaborn as sns
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.collections import LineCollection

# 1. Setup Spark (Use existing data)
spark = SparkSession.builder \
    .appName("WhiteDwarf_Hunter") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

Loading Data

Code
# Data for sources within 100 parsecs
df = spark.read.parquet("../data/gaia_100pc.parquet")
df.createOrReplaceTempView("gaia")

Color-Magnitude Method

We’re isolating white dwarfs within 100 parsecs using Gaia’s precise astrometry and photometry. White dwarfs occupy a distinct region in the HR diagram:

  • Faint, but not so faint that detections become noisy (\(10 < M_G < 15\))
  • Blue-to-yellow colors where the cooling sequence is densest (\(BP-RP < 1.0\))
  • Excluding outliers to remove rare hot subdwarfs (\(BP-RP \ge -0.5\))

We calculate the absolute magnitude \(M_G\) with

\[ M_G = m + 5 \log_{10}(parallax) - 10 \]

where \(m\) is the apparent G magnitude (phot_g_mean_mag) and the parallax is in milliarcseconds.
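
The constant in this formula follows from the distance modulus \(M = m - 5\log_{10}(d / 10\,\mathrm{pc})\) with \(d = 1000 / parallax\) when the parallax is in milliarcseconds. A short sketch with hypothetical helper names and an illustrative star (not taken from the dataset) checks that both forms agree and applies the colour and magnitude cuts listed above:

```python
import math


def absolute_magnitude(m_app: float, parallax_mas: float) -> float:
    """M = m + 5*log10(parallax) - 10, with parallax in milliarcseconds."""
    return m_app + 5 * math.log10(parallax_mas) - 10


def absolute_magnitude_from_distance(m_app: float, dist_pc: float) -> float:
    """Standard distance-modulus form, M = m - 5*log10(d / 10 pc)."""
    return m_app - 5 * math.log10(dist_pc / 10)


def is_white_dwarf_candidate(bp_rp: float, abs_mag: float) -> bool:
    """Colour-magnitude box used in this section."""
    return -0.5 <= bp_rp < 1.0 and 10 < abs_mag < 15


# Illustrative star: apparent magnitude 16.0 at parallax 20 mas (50 pc).
m_g = absolute_magnitude(16.0, 20.0)
m_g_check = absolute_magnitude_from_distance(16.0, 1000 / 20.0)
print(round(m_g, 3), round(m_g_check, 3), is_white_dwarf_candidate(0.2, m_g))
```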

Code
# Find the White Dwarfs candidates using Color-Magnitude method
df_wd = spark.sql("""
    WITH candidates AS (
        SELECT *,
               phot_g_mean_mag + 5 * LOG10(parallax) - 10 AS absolute_magnitude
        FROM gaia
        WHERE parallax > 0
          AND bp_rp >= -0.5
          AND bp_rp < 1.0
          AND phot_g_mean_mag IS NOT NULL
          AND bp_rp IS NOT NULL
    )
    SELECT *,
           CASE
             WHEN absolute_magnitude > 10 AND absolute_magnitude < 15
             THEN 'White Dwarf'
             ELSE 'Main Sequence / Other'
           END AS type
    FROM candidates
""")

# Trigger computation and cache
df_wd.cache().count()

# Count WDs
wd_count = df_wd.filter(col("type") == "White Dwarf").count()
print(f">> FOUND: {wd_count} White Dwarfs using Color-Magnitude method.")
>> FOUND: 12580 White Dwarfs using Color-Magnitude method.
Code
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from pyspark.sql.functions import col

# === 1. OPTIMIZED DATA LOADING ===
# Only pull necessary columns + filter in Spark
wd_pandas = (
    df_wd.filter(col("type") == "White Dwarf")
    .select("bp_rp", "absolute_magnitude") 
    .filter(
        (col("bp_rp").between(-0.5, 1.5)) & 
        (col("absolute_magnitude").between(8, 16)) &
        col("bp_rp").isNotNull() &
        col("absolute_magnitude").isNotNull()
    )
    .toPandas()
)

spark.stop()

# === 2. VECTORIZED DENSITY CALCULATION ===
x = wd_pandas['bp_rp'].values
y = wd_pandas['absolute_magnitude'].values

# Precompute bin edges 
x_edges = np.linspace(-0.5, 1.5, 151)
y_edges = np.linspace(8, 16, 151)

# Compute histogram
H, _, _ = np.histogram2d(x, y, bins=[x_edges, y_edges])

# Digitize using precomputed edges
x_inds = np.digitize(x, x_edges[1:-1])
y_inds = np.digitize(y, y_edges[1:-1])

# Get densities directly from histogram array
z = H[x_inds, y_inds]

# === 3. SORTING OPTIMIZATION ===
# If there are more than 10k points, keep only the 10k densest ones;
# in both cases order by density (z) so the densest points are drawn last
if len(z) > 10000:
    idx = np.argpartition(z, -10000)[-10000:]  # Indices of the 10k densest points (unordered)
    idx = idx[np.argsort(z[idx])]              # Sort that subset by density
else:
    idx = z.argsort()
x_sorted, y_sorted, z_sorted = x[idx], y[idx], z[idx]

# === 4. PLOTTING ===
plt.figure(dpi=120)

# Use scatter with pre-sorted points (densest on top)
sc = plt.scatter(
    x_sorted, y_sorted,
    c=z_sorted,
    s=1.5,
    cmap='gist_heat',
    norm=colors.LogNorm(vmin=1, vmax=np.max(H)),  # Precomputed vmax
    alpha=0.95,
    edgecolors='none',
    rasterized=True  # draw the points as a bitmap in vector output (smaller, faster files)
)

# === 5. ASTRONOMY POLISH ===
ax = plt.gca()
ax.invert_yaxis()

ax.set_title("Gaia DR3: White Dwarf Cooling Sequence", fontsize=14)
ax.set_xlabel("Gaia BP–RP colour", fontsize=12)
ax.set_ylabel("Gaia G absolute magnitude", fontsize=12)

# Colorbar with precomputed norm
cbar = plt.colorbar(sc, pad=0.02)
cbar.set_label('Stars per bin (log scale)', rotation=270, labelpad=20)

ax.grid(True, alpha=0.2, linewidth=0.5)
plt.tight_layout(pad=1.5)  # Extra padding around the axes

plt.show()
plt.close()  # Free memory immediately

What This Plot Reveals

This is the white dwarf cooling sequence, the evolutionary path of dead stars in our cosmic neighborhood. Here’s what it tells us:

  1. The Diagonal Band
    • White dwarfs start hot and blue (top-left: \(BP-RP ≈ -0.3, M_G ≈ 10\))
    • They cool and redden over billions of years, moving down and right (bottom-right: \(BP-RP ≈ 1.0, M_G ≈ 15\))
  2. The Color Gradient
    • Bright orange/white regions: High density of white dwarfs (common evolutionary stages)
    • Dark red regions: Fewer white dwarfs (rare or short-lived phases)
  3. The Smooth Curve
    • This sequence is a stellar “fossil record”—it reveals how long white dwarfs have been cooling
    • The gap at top-left (BP-RP < -0.2) shows very hot white dwarfs are rare (they cool quickly)
    • The smooth curve confirms white dwarfs cool predictably—like cosmic thermometers
  4. The “Bifurcation”
    • The population is segregated: if we look closely at the diagonal band, it isn’t a single smear; it’s split into two distinct, parallel “ridges” or tracks. White Dwarfs are not a homogeneous group. This split typically represents a difference in atmospheric composition.
      • Track A (Top/Blue ridge): Likely stars with Hydrogen-rich atmospheres (DA white dwarfs). Hydrogen is lighter and more opaque, acting as a blanket that keeps heat in differently than Helium.
      • Track B (Bottom/Red ridge): Likely stars with Helium-rich atmospheres (DB white dwarfs).

Key insight: The densest part (bright orange) is where most white dwarfs “spend” their lives—proving they cool slowly over billions of years. This plot is why astronomers call white dwarfs “cosmic clocks” for measuring the age of our galaxy.

2.3.2 Red Giant Candidates

Red Giants are old, dying stars. They are very cool but very bright, hence they are on the top-right of the HR-diagram.

Code
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# 1. Setup Spark 
spark = SparkSession.builder \
    .appName("RedGiant_DensityPlot") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

# Load the survey dataset (columns are selected later in the SQL queries)
df = spark.read.parquet("../data/gaia_survey.parquet")


# create temp view
df.createOrReplaceTempView("gaia_source")

We use two separate queries for the visualization: one to get the general picture and show the relative position of the red giants on the HR diagram, and the other to zoom in so we can analyze them better.

Code
df_all = spark.sql("""
    SELECT 
        bp_rp,
        phot_g_mean_mag + 5 * LOG10(parallax) - 10 AS absolute_magnitude
    FROM gaia_source
    WHERE parallax > 0
      AND parallax/parallax_error > 10
      AND bp_rp BETWEEN -0.5 AND 3.5
      -- AND (phot_g_mean_mag + 5 * LOG10(parallax) - 25) < 12
""")

df_rg = spark.sql("""
    SELECT 
        bp_rp,
        phot_g_mean_mag + 5 * LOG10(parallax) - 10 AS absolute_magnitude,
        parallax/parallax_error AS parallax_snr
    FROM gaia_source
    WHERE parallax > 0
      AND parallax/parallax_error > 10  -- Good parallax quality
      AND bp_rp BETWEEN 0.7 AND 2.5     -- Focus on RGB color range
      AND phot_g_mean_mag + 5 * LOG10(parallax) - 10 < 4.0  -- Bright giants
      AND bp_rp IS NOT NULL
      AND phot_g_mean_mag IS NOT NULL
""")

# Convert to pandas for plotting
all_pandas = df_all.toPandas()
rg_pandas = df_rg.toPandas()
spark.stop()
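
Both queries keep only stars with parallax/parallax_error > 10. Propagating the parallax error through the absolute magnitude formula gives, to first order, \(\sigma_M \approx (5 / \ln 10)\,(\sigma_\varpi / \varpi) \approx 2.17 / \mathrm{SNR}\) magnitudes, ignoring photometric errors and extinction. A small sketch (hypothetical helper name) of what that cut implies:

```python
import math


def abs_mag_uncertainty(parallax_snr: float) -> float:
    """First-order uncertainty on M from parallax alone: (5 / ln 10) / SNR."""
    if parallax_snr <= 0:
        raise ValueError("SNR must be positive")
    return 5.0 / math.log(10) / parallax_snr


for snr in (5, 10, 50):
    print(f"SNR = {snr:>2}: sigma_M ~ {abs_mag_uncertainty(snr):.3f} mag")
```

At the threshold of 10 the vertical scatter from parallax alone is about 0.2 mag, small compared with the several-magnitude extent of the giant branch.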
Code
# print no. of stars found
# print(f"Found {len(all_pandas)} stars in the survey.")

hb = plt.hexbin(
    x=all_pandas['bp_rp'], y=all_pandas['absolute_magnitude'],
    gridsize=500,              # High res = no blocky look
    extent=[-0.5, 3.5, -5, 12], # Fixed window for consistency
    norm=colors.LogNorm(),     # Essential: Compresses the dynamic range
    cmap='inferno',            # Perceptually uniform (Blue/Black -> Red -> Yellow)
    mincnt=1                   # Don't plot empty space
)

# Annotations
ax = plt.gca()
ax.invert_yaxis()

# A. The Main Sequence (High Density)
ax.text(1.6, 10.0, 'Main Sequence', color='purple', fontsize=10, 
        ha='right', rotation=-30)


# C. The True Evolutionary Path (Sub-Giant Branch)
# This arrow follows the curve, not a straight line
# Coordinates: Turn-off point -> Base of RGB
ax.annotate('', xy=(1.0, 0), xytext=(0.2,- 0.2),
            arrowprops=dict(arrowstyle='->', lw=2, color='cyan', connectionstyle="arc3,rad=-0.2"))
ax.text(0,0, 'Giant\nBranch', color='purple', fontsize=9, ha='center')

# Polish
ax.set_xlabel("Gaia BP–RP colour", fontsize=12)
ax.set_ylabel("Absolute Magnitude ($M_G$)", fontsize=12)

# Colorbar
cb = plt.colorbar(hb, pad=0.02)
cb.set_label('Star Density (Log Scale)', rotation=270, labelpad=20)

plt.tight_layout()
plt.show()

# print no. of stars found
# print(f"Found {len(rg_pandas)} stars in the survey.")

hb = plt.hexbin(
    x=rg_pandas['bp_rp'], y=rg_pandas['absolute_magnitude'],
    gridsize=500,              # High res = no blocky look
    # extent=[-0.5, 3.5, -5, 12], # Fixed window for consistency
    norm=colors.LogNorm(),     # Essential: Compresses the dynamic range
    cmap='inferno',            # Perceptually uniform (Blue/Black -> Red -> Yellow)
    mincnt=1                   # Don't plot empty space
)

# Annotations
ax = plt.gca()
ax.invert_yaxis()

# A. The Main Sequence (High Density)
# ax.text(1.6, 3, 'Main Sequence', color='white', fontsize=10, 
#         ha='right', rotation=-30)

plt.text(0.8, 3.5, 'Main Sequence', 
         color='white', fontsize=8, fontweight='bold',
         bbox=dict(facecolor='none', edgecolor='white', alpha=0.5))

# C. The True Evolutionary Path (Sub-Giant Branch)
# This arrow follows the curve, not a straight line
# Coordinates: Turn-off point -> Base of RGB
# ax.annotate('', xy=(1.0, 0), xytext=(0.2,- 0.2),
#             arrowprops=dict(arrowstyle='->', lw=2, color='cyan', connectionstyle="arc3,rad=-0.2"))
# ax.text(0,0, 'Giant\nBranch', color='purple', fontsize=9, ha='center')

# Polish
ax.set_xlabel("Gaia BP–RP colour", fontsize=12)
ax.set_ylabel("Absolute Magnitude ($M_G$)", fontsize=12)

# Colorbar
cb = plt.colorbar(hb, pad=0.02)
cb.set_label('Star Density (Log Scale)', rotation=270, labelpad=20)

plt.tight_layout()
plt.show()

Zoomed out

Zoomed in

Gaia DR3: Stellar Density Map

What This Plot Reveals

This is the Red Giant Branch (RGB), the evolutionary path of aging, low- to intermediate-mass stars as they exhaust hydrogen in their cores and begin shell burning. Here’s what it tells us:

  1. The Bright, Red Clump
    • Red giants are cool (\(BP–RP ≈ 1.0–2.5\)) but extremely luminous (\(M_G < 0\)), placing them in the upper-right of the HR diagram.
    • The densest region, the Red Clump (around \(M_G ≈ −1, BP–RP ≈ 1.3\)), marks stars stably burning helium in their cores. This acts as a “standard candle” for distance measurements.
  2. The Ascending Giant Branch
    • Stars move upward (brighter) and slightly redder as their outer envelopes expand after leaving the Main Sequence.
    • The smooth, curved track reflects predictable stellar evolution governed by mass and composition.
  3. Low-Mass vs. Upper RGB
    • At fainter magnitudes (\(M_G ≈ 2–4\)), we see lower-mass giants just starting their ascent.
    • The brightest, reddest stars (\(M_G < −1\)) are near the tip of the RGB, where helium ignition occurs in a dramatic flash for low-mass stars.
  4. Why It Matters
    • The RGB is a stellar aging sequence: its shape and density reveal the star formation history of our galactic neighborhood.
    • Because red giants are bright, they’re visible across vast distances—making them key tracers of galactic structure.

TLDR: this plot shows stars in their retirement, glowing brightly as they near the end of their lives—before shedding their outer layers to become white dwarfs.

2.3.3 Co-Moving Pair Search (Binary Candidates)

Binary stars are gravitationally bound systems sharing common motion through space. We can detect them by finding pairs with:

  • Similar positions (angular proximity)
  • Similar distances (parallax values)
  • Similar proper motions (space velocity vectors)
Code
import pyspark.sql.functions as F
# initialize spark session
spark = SparkSession.builder \
    .appName("Co-Moving_PairSearch") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

df = spark.read.parquet("../data/gaia_100pc.parquet")

# Add coarse sky bins (1° × 1°) - simple but effective
df = df.withColumn("ra_bin", F.floor(F.col("ra") / 1)) \
       .withColumn("dec_bin", F.floor(F.col("dec") / 1))
df.createOrReplaceTempView("stars")
Code
pairs = spark.sql("""
    SELECT A.source_id as id1, B.source_id as id2,
           A.ra, A.dec, A.parallax, 
           A.pmra, A.pmdec,
           B.ra as ra2, B.dec as dec2, B.parallax as plx2,
           B.pmra as pmra2, B.pmdec as pmdec2
    FROM stars A JOIN stars B
      ON A.ra_bin = B.ra_bin AND A.dec_bin = B.dec_bin
      AND A.source_id < B.source_id
    WHERE ABS(A.parallax - B.parallax) < 1      -- Same distance (within 1 mas)
      AND ABS(A.pmra - B.pmra) < 1              -- Same motion
      AND ABS(A.pmdec - B.pmdec) < 1
""")
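
One caveat of joining only on identical ra_bin and dec_bin is that two stars on opposite sides of a bin edge never meet in the join, however close they are. A common workaround, sketched here in plain Python with made-up positions and hypothetical names, is to compare each cell with its neighbouring cells as well. This illustrates the idea only; it is not the method used in the notebook, and it ignores RA wrap-around at 0°/360°.

```python
import math
from collections import defaultdict
from itertools import product

# Made-up stars: (id, ra, dec) in degrees. Stars 1 and 2 straddle the ra = 11 bin edge.
stars = [(1, 10.9999, 20.5), (2, 11.0001, 20.5), (3, 50.3, -10.2)]

# Index the stars by their 1-degree cell.
cells = defaultdict(list)
for sid, ra, dec in stars:
    cells[(math.floor(ra), math.floor(dec))].append((sid, ra, dec))


def candidate_pairs(include_neighbours: bool):
    """Return (id1, id2) pairs, id1 < id2, sharing a cell (or an adjacent cell)."""
    offsets = list(product((-1, 0, 1), repeat=2)) if include_neighbours else [(0, 0)]
    pairs = set()
    for (rb, db), members in cells.items():
        for d_ra, d_dec in offsets:
            # .get() avoids creating empty cells while iterating over the dict
            for a in members:
                for b in cells.get((rb + d_ra, db + d_dec), []):
                    if a[0] < b[0]:
                        pairs.add((a[0], b[0]))
    return pairs


same_cell = candidate_pairs(include_neighbours=False)
with_neighbours = candidate_pairs(include_neighbours=True)
print(same_cell, with_neighbours)
```

In this toy case the same-cell search finds no pair, while the neighbour-aware search recovers the pair that straddles the edge. The parallax and proper-motion filters would then be applied to these candidates exactly as in the query.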
Code
# Distance to each star (parsecs)
pairs = pairs.withColumn("dist_pc", 1000 / F.col("parallax"))

# Angular separation (degrees) and physical separation (AU)
pairs = pairs.withColumn(
    "ang_sep_deg", 
    F.degrees(F.acos(F.sin(F.radians("dec")) * F.sin(F.radians("dec2")) + 
                    F.cos(F.radians("dec")) * F.cos(F.radians("dec2")) * 
                    F.cos(F.radians("ra") - F.radians("ra2"))))
).withColumn("sep_au", F.col("ang_sep_deg") * 3600 * F.col("dist_pc"))

# Keep only likely physical pairs
binaries = pairs.filter(F.col("sep_au") < 10000)
print(f">> Found: {binaries.count()} candidate binaries")
>> Found: 5129 candidate binaries
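The arccosine form of the separation formula loses numerical precision for very small angles, which is exactly the regime of close pairs. A minimal sketch of the haversine form, together with the small-angle conversion used above (separation in AU = angle in arcseconds × distance in parsecs), is given below in plain Python. The function names and the example coordinates are illustrative only.

```python
import math

def ang_sep_deg(ra1, dec1, ra2, dec2):
    """Great-circle separation in degrees via the haversine formula."""
    ra1, dec1, ra2, dec2 = map(math.radians, (ra1, dec1, ra2, dec2))
    h = (math.sin((dec2 - dec1) / 2) ** 2
         + math.cos(dec1) * math.cos(dec2) * math.sin((ra2 - ra1) / 2) ** 2)
    return math.degrees(2 * math.asin(math.sqrt(h)))

def sep_au(ang_deg, parallax_mas):
    """Projected separation in AU: arcseconds times distance in parsecs."""
    dist_pc = 1000.0 / parallax_mas
    return ang_deg * 3600.0 * dist_pc

# Illustrative pair: 10 arcsec apart on the equator, parallax 20 mas (50 pc)
theta = ang_sep_deg(10.0, 0.0, 10.0 + 10.0 / 3600.0, 0.0)
projected = sep_au(theta, 20.0)
print(f"{theta * 3600:.3f} arcsec -> {projected:.1f} AU")
```

The same expression can be written with the F.sin, F.cos, F.asin and F.sqrt column functions if the haversine form is preferred inside Spark.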
                                                                                
Code
# Get photometry for plotting (join back to original data)
binaries_with_phot = binaries.alias("p").join(
    df.select("source_id", "bp_rp", "phot_g_mean_mag").alias("phot"),
    F.col("p.id1") == F.col("phot.source_id")
).join(
    df.select("source_id", "bp_rp", "phot_g_mean_mag").alias("phot2"),
    F.col("p.id2") == F.col("phot2.source_id")
).select(
    "p.*", 
    F.col("phot.bp_rp").alias("color1"),
    F.col("phot.phot_g_mean_mag").alias("mag1"),
    F.col("phot2.bp_rp").alias("color2"),
    F.col("phot2.phot_g_mean_mag").alias("mag2")
)

# Collect to pandas, then compute absolute magnitudes:
# M_G = G + 5*log10(parallax [mas]) - 10
plot_df = binaries_with_phot.toPandas()
plot_df['abs_mag1'] = plot_df['mag1'] + 5*np.log10(plot_df['parallax']) - 10
plot_df['abs_mag2'] = plot_df['mag2'] + 5*np.log10(plot_df['plx2']) - 10
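The expression used for abs_mag1 and abs_mag2 follows from the distance modulus M = m - 5 log10(d / 10 pc) with d = 1000 / parallax (parallax in mas), which simplifies to M = m + 5 log10(parallax) - 10. A short check in plain Python is sketched below; the star values are illustrative and extinction is ignored.

```python
import math

def absolute_mag(apparent_mag, parallax_mas):
    """M = m + 5*log10(parallax_mas) - 10 (no extinction correction)."""
    return apparent_mag + 5 * math.log10(parallax_mas) - 10

def absolute_mag_from_distance(apparent_mag, dist_pc):
    """Equivalent distance-modulus form: M = m - 5*log10(d / 10 pc)."""
    return apparent_mag - 5 * math.log10(dist_pc / 10.0)

# A star at exactly 10 pc (parallax 100 mas) has M equal to m by definition
m_10pc = absolute_mag(8.0, 100.0)

# Both forms agree for an arbitrary illustrative star (parallax 25 mas = 40 pc)
m_a = absolute_mag(12.0, 25.0)
m_b = absolute_mag_from_distance(12.0, 1000.0 / 25.0)
print(m_10pc, round(m_a, 3), round(m_b, 3))
```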
Code
# Analyze the distribution
print("\n=== Binary Separation Distribution ===")
print(f"Median separation: {plot_df['sep_au'].median():.1f} AU")
print(f"Mean separation: {plot_df['sep_au'].mean():.1f} AU")
print(f"Closest pair: {plot_df['sep_au'].min():.1f} AU")
print(f"Widest pair: {plot_df['sep_au'].max():.1f} AU")

# Classification by separation
# (pairs were already filtered to < 10,000 AU, so the 'Very Wide' default is never assigned)
plot_df['binary_type'] = np.select([
    plot_df['sep_au'] < 100,
    plot_df['sep_au'] < 1000,
    plot_df['sep_au'] < 10000
], ['Close', 'Intermediate', 'Wide'], 'Very Wide')

=== Binary Separation Distribution ===
Median separation: 1685.8 AU
Mean separation: 2738.5 AU
Closest pair: 31.2 AU
Widest pair: 9994.3 AU
Code
# Plot: Sample 100 pairs for clarity
sample = plot_df.sample(n=min(100, len(plot_df)))


plt.gca().invert_yaxis()
plt.xlabel("Gaia BP-RP Color")
plt.ylabel("Absolute Magnitude ($M_G$)")

# Draw connecting lines
lines = [[(r['color1'], r['abs_mag1']), (r['color2'], r['abs_mag2'])] 
         for _, r in sample.iterrows()]
lc = LineCollection(lines, colors='gray', alpha=0.3, linewidths=0.5)
plt.gca().add_collection(lc)

# Plot the stars
plt.scatter(sample['color1'], sample['abs_mag1'], s=30, c='skyblue', label='Star A')
plt.scatter(sample['color2'], sample['abs_mag2'], s=30, c='orange', label='Star B')

plt.legend()
plt.tight_layout()
plt.show()

# histogram
plt.hist(plot_df['sep_au'], bins=50, log=True, color='steelblue', alpha=0.7)
plt.axvline(100, color='red', linestyle='--', label='Close (<100 AU)')
plt.axvline(1000, color='orange', linestyle='--', label='Intermediate')

# set labels

plt.xlabel("Physical Separation (AU)")
plt.ylabel("Count (log scale)")
plt.legend()
plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

Figure: Binary stars within 100 parsecs. Panels: Binary Systems on HR Diagram; Binary Separation Distribution.

Code
spark.stop()

3 ML-Models

In this section, we will look at the machine learning models that we have trained and evaluated on the Gaia Dataset.

3.1 Jasmi

Two supervised learning models were implemented to classify Gaia stars into hot and cool populations. Logistic Regression was used as a baseline linear classifier, while Random Forest was employed to capture non-linear relationships between stellar colour, brightness, and distance. The performance of both models was evaluated and compared using standard classification metrics.

Code
# Importing all the required libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, log10, lit, count, when, sqrt, pow, percentile_approx
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.functions import udf

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.ml.linalg import VectorUDT

import os

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

from matplotlib import colors
Code
# ====================================================
# SETUP & LOAD
# ====================================================
# Initialise Spark session once
spark = SparkSession.builder \
    .appName("Gaia_HR_Analysis") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

# Define Paths
survey_path = "../data/gaia_survey.parquet"
local_path = "../data/gaia_100pc.parquet"

# Load Datasets
print(">>> LOADING DATASETS...")
df_survey = spark.read.parquet(survey_path)
df_local = spark.read.parquet(local_path)
>>> LOADING DATASETS...

3.1.1 Random Forest Classifier

Models will be trained using the local Gaia dataset (within 100 pc) to ensure reliable temperature estimates and clearer stellar population separation when classifying stars into hot and cool categories.

Code
df_labeled = (
    df_local
    .filter(col("teff_gspphot").isNotNull())
    .withColumn(
        "label",
        when(col("teff_gspphot") <= 5200, 0).otherwise(1)
    )
)

A binary label was created to support supervised classification. Stars with an effective temperature (teff_gspphot) of ≤ 5200 K were classified as Cool (0), while stars with temperatures > 5200 K were classified as Hot (1). Rows with missing temperature values were removed to ensure valid model training. This threshold represents a physically meaningful division between cooler and hotter stellar populations and produces a label format compatible with Spark MLlib classifiers.

Code
print("\n>>> LABEL DISTRIBUTION IN LABELED DATASET")
df_labeled.groupBy("label").count().show()

>>> LABEL DISTRIBUTION IN LABELED DATASET
+-----+-----+
|label|count|
+-----+-----+
|    1|25192|
|    0|92886|
+-----+-----+

The labeled dataset is imbalanced, with 92,886 stars (79%) classified as Cool and 25,192 stars (21%) classified as Hot. This imbalance reflects the physical reality of the local stellar population, where cooler, lower-mass stars are significantly more common than hot, massive stars. To address it, class weights are applied during model training so that the classifier learns to identify both classes effectively.

Code
# ----------------------------------------------------
# Select features and label
# ----------------------------------------------------
feature_cols = [
    "bp_rp",
    "phot_g_mean_mag",
    "parallax"
]

df_model = (
    df_labeled
    .select(feature_cols + ["label"])
    .dropna()
)

# ----------------------------------------------------
# Handle class imbalance using class weights
# ----------------------------------------------------
# Compute class counts
label_counts = df_model.groupBy("label").count().collect()
total = sum(row["count"] for row in label_counts)

weights = {
    row["label"]: total / (2 * row["count"])
    for row in label_counts
}

df_model = df_model.withColumn(
    "weight",
    when(col("label") == 0, weights[0]).otherwise(weights[1])
)

# ----------------------------------------------------
# Assemble feature vector
# ----------------------------------------------------
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features"
)

df_model = assembler.transform(df_model).select(
    "features", "label", "weight"
)

df_model.show(5)
+--------------------+-----+------------------+
|            features|label|            weight|
+--------------------+-----+------------------+
|[2.2142333984375,...|    0|0.6356085482047693|
|[3.34372138977050...|    0|0.6356085482047693|
|[2.82325744628906...|    0|0.6356085482047693|
|[1.90689754486083...|    0|0.6356085482047693|
|[2.55687236785888...|    0|0.6356085482047693|
+--------------------+-----+------------------+
only showing top 5 rows

The output confirms successful feature assembly and class weighting. Each row now contains a features vector combining the selected physical attributes (bp_rp, phot_g_mean_mag, and parallax), a binary class label, and an associated weight used during model training.

Stars in the majority class (Cool, label 0) are assigned a lower weight, while stars in the minority class (Hot, label 1) receive a higher weight. This weighting compensates for class imbalance by penalizing misclassification of the minority class more strongly, encouraging the model to learn meaningful decision boundaries for both stellar populations.
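As a worked example of the weighting rule weight = total / (2 × class count), the snippet below plugs in the label counts printed earlier. It is plain Python for illustration only. The notebook's value of 0.63560854… differs slightly in the later digits, presumably because the weights there are computed after dropna() on the feature columns.

```python
counts = {0: 92886, 1: 25192}          # Cool, Hot (label distribution printed earlier)
total = sum(counts.values())
n_classes = len(counts)

# "Balanced" weights: each class ends up contributing the same total weight
weights = {label: total / (n_classes * n) for label, n in counts.items()}
weighted_mass = {label: weights[label] * n for label, n in counts.items()}

for label in counts:
    print(label, round(weights[label], 4), round(weighted_mass[label], 1))
```

With these weights the Cool and Hot classes each carry half of the total weight, which is why a Hot star counts roughly 3.7 times as much as a Cool star during training.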

Code
#Split data into training and test sets
train_df, test_df = df_model.randomSplit([0.8, 0.2], seed=42)

#Define Random Forest Classifier
rf = RandomForestClassifier(
    labelCol="label",
    featuresCol="features",
    weightCol="weight",
    seed=42
)

# Build Parameter Grid for Hyperparameter Tuning


paramGrid = (
    ParamGridBuilder()
    .addGrid(rf.numTrees, [50, 100])
    .addGrid(rf.maxDepth, [5, 8, 10])
    .build()
)

# Define Evaluator
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="f1"
)

# Define CrossValidator
cv = CrossValidator(
    estimator=rf,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=3,
    parallelism=2
)

cv_model = cv.fit(train_df)
best_model = cv_model.bestModel

print("\n>>> BEST MODEL PARAMETERS")
print("Best numTrees:", best_model.getNumTrees)
print("Best maxDepth:", best_model.getMaxDepth())

>>> BEST MODEL PARAMETERS
Best numTrees: 100
Best maxDepth: 8

To allow evaluation on unseen data, the dataset was partitioned into an 80% training set and a 20% test set. A Random Forest classifier was then initialised with the weightCol parameter set to the previously calculated class weights, so that the model does not favour the majority class at the expense of the minority ‘Hot’ stars.

Model optimisation was conducted using a grid search over tree counts (50 and 100) and maximum depths (5, 8, and 10). Performance was evaluated via 3-fold cross-validation using the F1 score. The best configuration used 100 trees with a maximum depth of 8.
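To make the cost of this search concrete, the following plain-Python sketch enumerates the grid. It only counts model fits and does not call Spark; the variable names are illustrative.

```python
from itertools import product

num_trees_grid = [50, 100]
max_depth_grid = [5, 8, 10]
num_folds = 3

combos = list(product(num_trees_grid, max_depth_grid))
cv_fits = len(combos) * num_folds      # one fit per combination per fold
total_fits = cv_fits + 1               # plus the final refit of the best combination

print(len(combos), cv_fits, total_fits)
```

Every extra grid value multiplies the number of forests trained, so the grid was kept small.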

Why f1_score?

Accuracy can be misleading for imbalanced datasets like ours (79% Cool stars, 21% Hot stars). A model that predicts mostly the majority class could achieve high accuracy while failing to identify the minority class. F1 score balances precision and recall, providing a more reliable measure of performance for both classes.
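A small numerical illustration of this point, in plain Python with a hypothetical test set of 1,000 stars split 79/21, is sketched below. The trivial "always Cool" model and the helper name `prf` are assumptions made for the example.

```python
def prf(tp, fp, fn):
    """Precision, recall and F1 for one class; 0.0 where a ratio is undefined."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1

# Hypothetical test set of 1,000 stars with the same 79/21 split
n_cool, n_hot = 790, 210

# A trivial model that labels every star as Cool
accuracy = n_cool / (n_cool + n_hot)
_, hot_recall, hot_f1 = prf(tp=0, fp=0, fn=n_hot)        # Hot as the positive class
_, _, cool_f1 = prf(tp=n_cool, fp=n_hot, fn=0)           # Cool as the positive class

print(f"accuracy={accuracy:.2f}, F1(Hot)={hot_f1:.2f}, F1(Cool)={cool_f1:.2f}")
```

The accuracy looks respectable, yet the F1 score for the Hot class is zero because not a single Hot star is recovered.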

Code
# Prediction on test set
predictions = best_model.transform(test_df)

# Compute F1 score
f1_score = evaluator.evaluate(predictions)
print(f"\n>>> TEST SET F1 SCORE: {f1_score:.4f}")

>>> TEST SET F1 SCORE: 0.9855

The Random Forest classifier achieved a test F1 score of 0.9855, showing that it distinguishes Hot and Cool stars reliably despite the class imbalance.

Code
#Feature Importance:

importances = best_model.featureImportances
print("\n>>> FEATURE IMPORTANCES:")
for i, col_name in enumerate(feature_cols):
    print(f"{col_name}: {importances[i]:.4f}")

#Visualise Feature Importance
plt.figure(figsize=(6,4))
sns.barplot(x=feature_cols, y=importances.toArray())
plt.title("Random Forest Feature Importance")
plt.ylabel("Importance")
plt.show()

>>> FEATURE IMPORTANCES:
bp_rp: 0.7178
phot_g_mean_mag: 0.2754
parallax: 0.0068

The Random Forest model relies primarily on the stellar color (bp_rp) to distinguish Hot and Cool stars, with magnitude providing additional, but smaller, contribution. Distance (parallax) has negligible influence in this local sample. This agrees with astrophysical expectations, as stellar temperature is closely linked to color.

Code
conf_matrix = predictions.groupBy("label", "prediction").count().orderBy("label", "prediction")
conf_matrix.show()

# Optional: Convert to Pandas for nicer display/plot
conf_matrix_pd = conf_matrix.toPandas().pivot(index='label', columns='prediction', values='count').fillna(0)

print("\n>>> CONFUSION MATRIX:")
print(conf_matrix_pd)

# Plot confusion matrix
plt.figure(figsize=(5,4))
sns.heatmap(conf_matrix_pd, annot=True, fmt="g", cmap="Blues")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Random Forest Confusion Matrix")
plt.show()
+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    0|       0.0|18310|
|    0|       1.0|  215|
|    1|       0.0|  129|
|    1|       1.0| 4955|
+-----+----------+-----+

>>> CONFUSION MATRIX:
prediction    0.0   1.0
label                  
0           18310   215
1             129  4955

Confusion Matrix

  • This matrix breaks the test-set performance down by class.

  • True Positives (Hot stars correctly identified): 4,955

  • True Negatives (Cool stars correctly identified): 18,310

  • False Negatives (Hot stars missed): Only 129

  • False Positives (Cool stars classified as Hot): Only 215

The low number of off-diagonal errors (344 of 23,609 test stars, about 1.5%) confirms the model is highly accurate.
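Spark's MulticlassClassificationEvaluator with metricName="f1" reports the support-weighted average of the per-class F1 scores. As a cross-check, the sketch below recomputes that value in plain Python from the four counts printed in the confusion matrix output; the helper name `class_f1` is illustrative.

```python
# Counts from the Random Forest confusion matrix output, keyed by (true label, prediction)
cm = {(0, 0): 18310, (0, 1): 215,
      (1, 0): 129,   (1, 1): 4955}

def class_f1(cm, cls):
    """F1 for one class, treating `cls` as the positive label."""
    tp = cm[(cls, cls)]
    fp = sum(n for (true, pred), n in cm.items() if pred == cls and true != cls)
    fn = sum(n for (true, pred), n in cm.items() if true == cls and pred != cls)
    return 2 * tp / (2 * tp + fp + fn)

total = sum(cm.values())
support = {c: sum(n for (true, _), n in cm.items() if true == c) for c in (0, 1)}
weighted_f1 = sum(support[c] / total * class_f1(cm, c) for c in (0, 1))
accuracy = (cm[(0, 0)] + cm[(1, 1)]) / total

print(f"F1(Cool)={class_f1(cm, 0):.4f}  F1(Hot)={class_f1(cm, 1):.4f}")
print(f"weighted F1={weighted_f1:.4f}  accuracy={accuracy:.4f}")
```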

Code
# UDF to extract individual features from features vector
def get_feature(idx):
    return udf(lambda v: float(v[idx]), "double")

# Add original features back to predictions
pred_rf_plot = predictions \
    .withColumn("bp_rp", get_feature(0)("features")) \
    .withColumn("phot_g_mean_mag", get_feature(1)("features"))

# Original test data (true labels)
test_plot_df = test_df \
    .withColumn("bp_rp", get_feature(0)("features")) \
    .withColumn("phot_g_mean_mag", get_feature(1)("features"))

# Convert both to Pandas
pred_rf_pd = pred_rf_plot.select("bp_rp", "phot_g_mean_mag", "prediction").toPandas()
test_pd = test_plot_df.select("bp_rp", "phot_g_mean_mag", "label").toPandas()
Code
# Plotting
fig, axes = plt.subplots(1, 2, figsize=(12,5), sharex=True, sharey=True)

# Original data
sns.scatterplot(
    ax=axes[0],
    data=test_pd,
    x="bp_rp",
    y="phot_g_mean_mag",
    hue="label",
    palette={0:"blue", 1:"red"},
    alpha=0.6
)
axes[0].invert_yaxis()
axes[0].set_title("Original Test Data")
axes[0].set_xlabel("BP - RP (Color)")
axes[0].set_ylabel("G Magnitude")

# Random Forest predictions
sns.scatterplot(
    ax=axes[1],
    data=pred_rf_pd,
    x="bp_rp",
    y="phot_g_mean_mag",
    hue="prediction",
    palette={0:"blue", 1:"red"},
    alpha=0.6
)
axes[1].invert_yaxis()
axes[1].set_title("Random Forest Predictions")
axes[1].set_xlabel("BP - RP (Color)")
axes[1].set_ylabel("G Magnitude")

plt.tight_layout()
plt.show()

Predicted vs. Actual Scatter Plots

The scatter plots below compare the Ground Truth (Original Test Data) against the Random Forest Predictions.

  • Left Plot (Original Data): Shows the true distribution of Hot (Red) and Cool (Blue) stars based on their physical properties.

  • Right Plot (RF Predictions): Shows how the model classified them. You can see the model captured the distinct, non-linear boundary between the two populations almost perfectly, with very few visible errors.

3.1.2 Logistic Regression

Code
# ---------------------------------------------
# Logistic Regression setup
# ---------------------------------------------
lr = LogisticRegression(
    featuresCol="features",
    labelCol="label",
    weightCol="weight",   # keep same weighting as RF
    maxIter=100,
    regParam=0.01,
    elasticNetParam=0.0   # L2 regularization
)

# ---------------------------------------------
# Fit Logistic Regression on same training set
# ---------------------------------------------
lr_model = lr.fit(train_df)

# ---------------------------------------------
# Predict on the same test set
# ---------------------------------------------
predictions_lr = lr_model.transform(test_df)

# ---------------------------------------------
# Evaluate using F1 and per-class metrics
# ---------------------------------------------
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="f1"
)
f1_lr = evaluator.evaluate(predictions_lr)
print(f"\n>>> Logistic Regression F1 Score: {f1_lr:.4f}")

# Class-wise metrics
predictionAndLabels = predictions_lr.select("prediction", "label").rdd.map(lambda row: (float(row['prediction']), float(row['label'])))
metrics = MulticlassMetrics(predictionAndLabels)

>>> Logistic Regression F1 Score: 0.9246

The F1 score of 0.9246 demonstrates strong predictive ability, though it is clearly lower than the Random Forest model's 0.9855, reflecting the limitation of a linear decision boundary.

Code
conf_matrix_lr = predictions_lr.groupBy("label", "prediction").count().orderBy("label", "prediction")
conf_matrix_lr.show()
conf_matrix_pd_lr = conf_matrix_lr.toPandas().pivot(index='label', columns='prediction', values='count').fillna(0)

#Plotting confusion matrix for Logistic Regression
plt.figure(figsize=(5,4))
sns.heatmap(conf_matrix_pd_lr, annot=True, fmt="g", cmap="Oranges")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Logistic Regression Confusion Matrix")
plt.show()
+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    0|       0.0|16679|
|    0|       1.0| 1846|
|    1|       0.0|   27|
|    1|       1.0| 5057|
+-----+----------+-----+

Confusion Matrix

  • This matrix shows a distinct difference in performance compared to the Random Forest model.

  • True Positives (5,057): The model is very good at catching Hot stars, slightly better than Random Forest (which had 4,955).

  • False Negatives (27): It misses very few Hot stars.

  • False Positives (1,846): This is the main weakness. The model incorrectly classifies more than 1,800 Cool stars as Hot. This suggests the model is “aggressive” in predicting the minority class but lacks precision.

We use the Logistic Regression coefficients to interpret the influence of each feature on the prediction. Each coefficient shows how a one-unit change in the feature affects the log-odds of a star being classified as Hot versus Cool, while keeping other features constant. This helps understand which physical attributes—color, magnitude, or distance—drive the model’s decisions, making the model interpretable rather than just a black box.

Code
coefficients = lr_model.coefficients
intercept = lr_model.intercept
print("\n>>> Logistic Regression Feature Coefficients:")
for i, col_name in enumerate(feature_cols):
    print(f"{col_name}: {coefficients[i]:.4f}")
print(f"Intercept: {intercept:.4f}")

>>> Logistic Regression Feature Coefficients:
bp_rp: -2.5535
phot_g_mean_mag: -0.5790
parallax: -0.0465
Intercept: 10.0029

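Stellar color (bp_rp) has the largest coefficient in magnitude, and its negative sign means that redder stars are less likely to be classified as Hot, consistent with astrophysical expectations. Because the features were not standardised before fitting, the coefficient magnitudes are only a rough guide to relative importance.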
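To make the log-odds interpretation concrete, the sketch below applies the printed coefficients to one made-up star in plain Python. The star's values are purely illustrative, and the resulting probability is what a logistic model with these rounded coefficients would give, not an output of the fitted Spark model.

```python
import math

# Coefficients and intercept as printed above (rounded to 4 decimals)
coef = {"bp_rp": -2.5535, "phot_g_mean_mag": -0.5790, "parallax": -0.0465}
intercept = 10.0029

def prob_hot(star):
    """P(label = 1) under the logistic model: sigmoid(intercept + sum(coef * x))."""
    z = intercept + sum(coef[name] * star[name] for name in coef)
    return 1.0 / (1.0 + math.exp(-z))

# Made-up star: fairly blue, G = 8 mag, parallax 20 mas
star = {"bp_rp": 0.8, "phot_g_mean_mag": 8.0, "parallax": 20.0}
redder = dict(star, bp_rp=star["bp_rp"] + 1.0)

p, p_redder = prob_hot(star), prob_hot(redder)
odds_ratio = (p_redder / (1 - p_redder)) / (p / (1 - p))   # equals exp(coef["bp_rp"])

print(f"P(Hot)={p:.3f}, after +1 in bp_rp: {p_redder:.3f}, odds ratio={odds_ratio:.4f}")
```

A one-unit increase in bp_rp multiplies the odds of the Hot class by exp(-2.5535), roughly 0.08, whatever the values of the other two features.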

Code
# UDF to extract individual features from features vector
def get_feature(idx):
    return udf(lambda v: float(v[idx]), "double")

# Add original features back to LR predictions
pred_lr_plot = predictions_lr \
    .withColumn("bp_rp", get_feature(0)("features")) \
    .withColumn("phot_g_mean_mag", get_feature(1)("features"))

# Original test data (true labels)
test_plot_df = test_df \
    .withColumn("bp_rp", get_feature(0)("features")) \
    .withColumn("phot_g_mean_mag", get_feature(1)("features"))

# Convert both to Pandas
pred_lr_pd = pred_lr_plot.select("bp_rp", "phot_g_mean_mag", "prediction").toPandas()
test_pd = test_plot_df.select("bp_rp", "phot_g_mean_mag", "label").toPandas()
Code
# Plotting
fig, axes = plt.subplots(1, 2, figsize=(12,5), sharex=True, sharey=True)

# Original data
sns.scatterplot(
    ax=axes[0],
    data=test_pd,
    x="bp_rp",
    y="phot_g_mean_mag",
    hue="label",
    palette={0:"blue", 1:"red"},
    alpha=0.6
)
axes[0].invert_yaxis()
axes[0].set_title("Original Test Data")
axes[0].set_xlabel("BP - RP (Color)")
axes[0].set_ylabel("G Magnitude")

# Logistic Regression predictions
sns.scatterplot(
    ax=axes[1],
    data=pred_lr_pd,
    x="bp_rp",
    y="phot_g_mean_mag",
    hue="prediction",
    palette={0:"blue", 1:"red"},
    alpha=0.6
)
axes[1].invert_yaxis()
axes[1].set_title("Logistic Regression Predictions")
axes[1].set_xlabel("BP - RP (Color)")
axes[1].set_ylabel("G Magnitude")

plt.tight_layout()
plt.show()

Predicted vs. Actual (HR Diagram)

The scatter plots below reveal why the Logistic Regression model struggled with False Positives.

  • Left Plot (Original Data): The boundary between Hot (Red) and Cool (Blue) stars is slightly curved and complex.

  • Right Plot (Logistic Regression Predictions): A straight line clearly cuts through the data. Because Logistic Regression is a linear classifier, it is forced to draw a straight decision boundary. It cannot bend to fit the physical curve of the stellar main sequence, so a large “wedge” of Cool stars in the middle of the diagram is incorrectly painted Red (Hot).
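The straight boundary follows directly from the model: with the default 0.5 threshold, the predicted class flips where the log-odds equal zero, which is a linear equation in the features. A minimal sketch in plain Python, using the rounded coefficients printed earlier and an arbitrary fixed parallax of 20 mas (an assumption for illustration), solves that equation for G as a function of colour.

```python
coef_bp_rp, coef_g, coef_plx = -2.5535, -0.5790, -0.0465
intercept = 10.0029

def boundary_g(bp_rp, parallax_mas=20.0):
    """G magnitude at which the log-odds are zero (P(Hot) = 0.5)."""
    return -(intercept + coef_bp_rp * bp_rp + coef_plx * parallax_mas) / coef_g

colours = [0.5, 1.0, 1.5, 2.0]
g_values = [boundary_g(c) for c in colours]

# Equal steps in colour give equal steps in G: the boundary is a straight line
steps = [round(b - a, 6) for a, b in zip(g_values, g_values[1:])]
print([round(g, 3) for g in g_values], steps)
```

Changing the parallax only shifts this line up or down in G; it cannot introduce curvature.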

3.1.3 Random Forest Vs Logistic Regression

Code
# Function to get class-wise metrics
def class_metrics(pred_df, model_name):
    rdd = pred_df.select("prediction", "label").rdd.map(lambda row: (float(row['prediction']), float(row['label'])))
    metrics = MulticlassMetrics(rdd)
    labels = sorted(rdd.map(lambda x: x[1]).distinct().collect())
    rows = []
    for label in labels:
        rows.append({
            "Class": int(label),
            "Precision": round(metrics.precision(label),4),
            "Recall": round(metrics.recall(label),4),
            "F1": round(metrics.fMeasure(label),4),
            "Model": model_name
        })
    return pd.DataFrame(rows)

metrics_rf_df = class_metrics(predictions, "Random Forest")
metrics_lr_df = class_metrics(predictions_lr, "Logistic Regression")

metrics_df = pd.concat([metrics_rf_df, metrics_lr_df], ignore_index=True)
print(">>> CLASS-WISE METRICS COMPARISON")
print(metrics_df)
>>> CLASS-WISE METRICS COMPARISON
   Class  Precision  Recall      F1                Model
0      0     0.9930  0.9884  0.9907        Random Forest
1      1     0.9584  0.9746  0.9665        Random Forest
2      0     0.9984  0.9004  0.9468  Logistic Regression
3      1     0.7326  0.9947  0.8437  Logistic Regression
  • Random Forest achieves high performance on both classes and keeps both precision (0.9584) and recall (0.9746) strong for the minority class (Hot stars, class 1).

  • Logistic Regression performs well on the majority class but struggles on the minority class: recall is high (0.9947) while precision drops to 0.7326, meaning roughly one in four stars it labels Hot is actually Cool.

  • These metrics highlight how Random Forest better captures non-linear relationships, while Logistic Regression provides a simpler, interpretable linear model.
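
As a quick consistency check on the table above, the F1 column can be recomputed from the Precision and Recall columns, since F1 is their harmonic mean. This is a minimal sketch in plain Python; the helper name `f1_score` is introduced here for illustration, and the inputs are the rounded values printed in the table, so agreement is only expected to about three decimal places.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# (precision, recall, reported F1) copied from the class-wise metrics table
reported = {
    ("Random Forest", 0): (0.9930, 0.9884, 0.9907),
    ("Random Forest", 1): (0.9584, 0.9746, 0.9665),
    ("Logistic Regression", 0): (0.9984, 0.9004, 0.9468),
    ("Logistic Regression", 1): (0.7326, 0.9947, 0.8437),
}

# Largest gap between recomputed and reported F1 across all rows
max_gap = 0.0
for (model, cls), (p, r, f1_reported) in reported.items():
    f1 = f1_score(p, r)
    max_gap = max(max_gap, abs(f1 - f1_reported))
    print(f"{model:20s} class {cls}: F1 = {f1:.4f} (reported {f1_reported:.4f})")
```

The harmonic mean is pulled towards the smaller of the two inputs, which is why the low precision of Logistic Regression on class 1 drags its F1 down despite near-perfect recall.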

Code
# Melt the DataFrame for easy seaborn plotting
metrics_melted = metrics_df.melt(
    id_vars=["Class", "Model"], 
    value_vars=["Precision", "Recall", "F1"],
    var_name="Metric",
    value_name="Score"
)

plt.figure(figsize=(7,5))
sns.barplot(
    data=metrics_melted,
    x="Class",
    y="Score",
    hue="Model",
    palette={"Random Forest":"pink","Logistic Regression":"green"},
    errorbar=None  # replaces the deprecated `ci=None`
)

# Add facet for each metric
g = sns.catplot(
    data=metrics_melted,
    x="Class",
    y="Score",
    hue="Model",
    col="Metric",
    kind="bar",
    palette={"Random Forest":"pink","Logistic Regression":"green"},
    height=4,
    aspect=0.8
)

g.set_titles("{col_name}")
g.set_axis_labels("Class (0=Cool, 1=Hot)", "Score")
g.set(ylim=(0,1))
plt.show()

Interpretation: The plots visually reinforce the numerical class-wise metrics, showing that Random Forest handles the minority class better, while Logistic Regression is more prone to false positives for Hot stars.

Random Forest is the stronger model for both classes, but Logistic Regression provides a linear, interpretable baseline.

Code
from pyspark.sql.functions import udf

# Build a UDF that extracts element `idx` from the assembled feature vector
def get_feature(idx):
    return udf(lambda v: float(v[idx]), "double")

# Random Forest: recover the two plotted features from the vector column
pred_rf_plot = predictions \
    .withColumn("bp_rp", get_feature(0)("features")) \
    .withColumn("phot_g_mean_mag", get_feature(1)("features"))

# Logistic Regression: same extraction on its predictions
pred_lr_plot = predictions_lr \
    .withColumn("bp_rp", get_feature(0)("features")) \
    .withColumn("phot_g_mean_mag", get_feature(1)("features"))

# Convert to Pandas for plotting
pred_rf_pd = pred_rf_plot.select("bp_rp", "phot_g_mean_mag", "label", "prediction").toPandas()
pred_lr_pd = pred_lr_plot.select("bp_rp", "phot_g_mean_mag", "label", "prediction").toPandas()
Code
#Plotting comparison
fig, axes = plt.subplots(1, 2, figsize=(12,5), sharex=True, sharey=True)

# Random Forest
sns.scatterplot(
    ax=axes[0],
    data=pred_rf_pd,
    x="bp_rp",
    y="phot_g_mean_mag",
    hue="prediction",      # predicted class
    style="label",         # true class
    palette={0:"blue", 1:"red"},
    alpha=0.6
)
axes[0].invert_yaxis()
axes[0].set_title("Random Forest Predictions")
axes[0].set_xlabel("BP - RP (Color)")
axes[0].set_ylabel("G Magnitude")

# Logistic Regression
sns.scatterplot(
    ax=axes[1],
    data=pred_lr_pd,
    x="bp_rp",
    y="phot_g_mean_mag",
    hue="prediction",      # predicted class
    style="label",         # true class
    palette={0:"blue", 1:"red"},
    alpha=0.6
)
axes[1].invert_yaxis()
axes[1].set_title("Logistic Regression Predictions")
axes[1].set_xlabel("BP - RP (Color)")
axes[1].set_ylabel("G Magnitude")

plt.tight_layout()
plt.show()

Random Forest (Left):

Non-Linearity: The decision boundary (where red meets blue) is irregular and curved. It closely follows the natural, physical gap between the main sequence (cool stars) and the turn-off/giant branch (hot stars).

Precision: There is very little “bleeding” of red points into the blue region. The model respects the complex structure of the data.

Logistic Regression (Right):

Linearity: The decision boundary is a distinct straight line.

Misclassification: Because the physical boundary of stellar populations is curved, the linear model cannot fit it perfectly. Notice the large “wedge” of blue stars (Cool) that are colored Red (predicted Hot) in the middle of the plot. These are the 1,972 False Positives identified in the confusion matrix earlier.

Conclusion:

The Random Forest model is significantly better suited for this task because stellar classification on an HR diagram requires a non-linear decision boundary, which a simple linear classifier like Logistic Regression cannot provide.

Code
spark.stop()

3.2 Yogi (Stellar Populations in Gaia DR3)

Using the gaia_survey dataset, the task shifts from acquiring data to making predictions. The objective is to construct, train, and evaluate classification models that can automatically categorize stellar objects based on their physical properties.

3.2.1 Purpose and Need for Implementation

  • Differentiation of Stellar Objects: The main goal is to use observations to differentiate “nearby dwarfs from distant giants” and other star populations.
  • Managing Wide-Ranging Data: To make the data suitable for the algorithms' splits, the task applies transformations such as log-scaling to handle the “massive range” of astronomical quantities (like parallax).
  • Automation: Large-scale survey data cannot be classified manually, so the task uses a pipeline to automate feature assembly, scaling, and classification.

3.2.2 Model Insights

What it is about: A supervised classification workflow is implemented in this problem. After creating training labels using “ground truth” logic based on astronomy (Absolute Magnitude vs. Colour), it trains algorithms to predict such labels using only observable features.

The Workflow:

  1. Feature Engineering: It constructs input features from raw data, such as calculating total_motion from the proper motion components (\(pmra\) and \(pmdec\)).

  2. Label Generation: It creates a label column by applying specific physical cuts:

  • White Dwarfs (Label 2.0): Defined as having Absolute Magnitude (\(M_G\)) \(> 10\)
  • Red Giants (Label 1.0): Defined as having \(M_G < 3\) and Color (\(bp\_rp\)) \(> 1.0\)
  • Main Sequence (Label 0.0): Everything else
  3. Algorithm Implementation: It implements two distinct algorithms to solve this problem:
  • Random Forest Classifier: A non-linear model suited for complex decision boundaries.
  • Logistic Regression: A linear model used for comparison, which includes feature scaling and elastic-net regularization.
  4. What it is trying to predict: The models aim to predict the label (Star Type) of a star given only its observable features (bp_rp, phot_g_mean_mag, total_motion, parallax, and the log-transformed log_parallax). It validates these predictions using Cross-Validation (3-Fold) to ensure the model is robust and “not just lucky”.
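
The label cuts in step 2 can be sketched in plain Python to make the rule order explicit. The function names `absolute_magnitude` and `star_label` and the three example stars are hypothetical, introduced only for illustration. The sketch assumes a positive parallax in milliarcseconds and uses the standard distance modulus, M_G = G - 5 log10(d) + 5 with d = 1000 / parallax in parsecs.

```python
import math

def absolute_magnitude(g_mag, parallax_mas):
    """Absolute magnitude from apparent G magnitude and parallax (mas > 0)."""
    distance_pc = 1000.0 / parallax_mas
    return g_mag - 5.0 * math.log10(distance_pc) + 5.0

def star_label(abs_mag, bp_rp):
    """Apply the cuts in order: White Dwarf first, then Red Giant, else Main Sequence."""
    if abs_mag > 10:
        return 2.0  # White Dwarf
    if abs_mag < 3 and bp_rp > 1.0:
        return 1.0  # Red Giant
    return 0.0      # Main Sequence

# Hypothetical stars: (apparent G magnitude, parallax in mas, bp_rp colour)
examples = [(16.0, 10.0, 0.2), (6.0, 1.0, 1.3), (10.0, 10.0, 0.8)]
labels = [star_label(absolute_magnitude(g, plx), color) for g, plx, color in examples]
print(labels)
```

Because the White Dwarf test runs first, any object fainter than M_G = 10 receives label 2.0 regardless of its colour.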

3.2.3 Phase 1: The Random Forest Model

Data Preparation & Features

Code
import os
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import VectorAssembler, StringIndexer, StandardScaler
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator



spark = SparkSession.builder \
    .appName("StellarPopulation_RF_Advanced") \
    .config("spark.driver.memory", "4g") \
    .config("spark.executor.memory", "4g") \
    .config("spark.memory.offHeap.enabled", "true") \
    .config("spark.memory.offHeap.size", "1g") \
    .config("spark.sql.shuffle.partitions", "200") \
    .getOrCreate()

# 1. Load the survey dataset
df = spark.read.parquet("../data/gaia_survey.parquet")

# 2. Calculate Observables (Features)
# We add 'parallax' because it is an OBSERVABLE. It allows the model to differentiate
# nearby dwarfs from distant giants.
df = df.withColumn("total_motion", F.sqrt(F.col("pmra")**2 + F.col("pmdec")**2))
# Log-transform parallax to handle the massive range (scales data for better splits).
# abs() and the 1e-6 offset keep log10 defined for zero or negative parallaxes.
df = df.withColumn("log_parallax", F.log10(F.abs(F.col("parallax")) + 1e-6))


# 3. Clean data: drop rows missing any model input
df = df.dropna(subset=["phot_g_mean_mag", "bp_rp", "parallax", "total_motion"])

# Verify it worked
print(f"Data loaded and cleaned. Rows: {df.count()}")
df.printSchema()
Data loaded and cleaned. Rows: 821788
root
 |-- source_id: long (nullable = true)
 |-- ra: double (nullable = true)
 |-- dec: double (nullable = true)
 |-- parallax: double (nullable = true)
 |-- parallax_error: float (nullable = true)
 |-- pmra: double (nullable = true)
 |-- pmdec: double (nullable = true)
 |-- phot_g_mean_mag: float (nullable = true)
 |-- bp_rp: float (nullable = true)
 |-- teff_gspphot: float (nullable = true)
 |-- total_motion: double (nullable = true)
 |-- log_parallax: double (nullable = true)

Handling Class Imbalance

A critical step in this code is addressing the fact that White Dwarfs are rare compared to Main Sequence stars.

  • The Problem: Without correction, a model could achieve high accuracy by simply ignoring the rare White Dwarfs.
  • The Solution: The code calculates Class Weights using the formula:\[\text{Weight} = \frac{\text{Total Rows}}{3.0 \times \text{Class Count}}\]where 3.0 is the number of classes. This assigns higher weights to rare classes (White Dwarfs) to force the model to pay attention to them during training.
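
A small worked example may help show what this formula achieves. The sketch below uses hypothetical class counts (not the Gaia counts) and a helper named `class_weights` introduced for illustration. After weighting, each class contributes the same total weight, which is one third of the row count.

```python
def class_weights(counts):
    """Weight = total rows / (number of classes * class count)."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {label: total / (n_classes * n) for label, n in counts.items()}

# Hypothetical counts for illustration only: 1000 rows across 3 classes
counts = {0.0: 700, 1.0: 250, 2.0: 50}
weights = class_weights(counts)

# Total weight carried by each class after weighting
weighted_totals = {label: weights[label] * n for label, n in counts.items()}

for label in counts:
    print(f"label {label}: weight = {weights[label]:.4f}, "
          f"weighted total = {weighted_totals[label]:.2f}")
```

The rarest class receives the largest per-row weight, yet every class ends up with an equal share of the overall training weight.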
Code
# --- A. Create Absolute Magnitude (M_G) for the Label Logic ---
# Distance d = 1000 / parallax (mas)
df = df.withColumn("distance", 1000 / F.col("parallax"))
df = df.withColumn("abs_mag", F.col("phot_g_mean_mag") - 5 * F.log10(F.col("distance")) + 5)

# --- B. Define the Classes (The "Ground Truth" Cuts) ---

df_labeled = df.withColumn("label", 
    F.when(F.col("abs_mag") > 10, 2.0)  # White Dwarf
     .when((F.col("abs_mag") < 3) & (F.col("bp_rp") > 1.0), 1.0) # Red Giant
     .otherwise(0.0) # Main Sequence
)

# --- C. Handle Class Imbalance (Weighting) ---
# Calculate class counts
class_counts = df_labeled.groupBy("label").count().collect()
total_rows = df_labeled.count()
count_map = {row['label']: row['count'] for row in class_counts}

# Calculate weights: Weight = Total / (Number of Classes * Class Count)
class_weights = {k: total_rows / (3.0 * v) for k, v in count_map.items()}
print(f">> Class Weights: {class_weights}")

# Build a label -> weight map expression and attach it as a column.
# Indexing with [] replaces the deprecated getItem(column) call.
mapping_expr = F.create_map([F.lit(x) for x in sum(class_weights.items(), ())])
df_weighted = df_labeled.withColumn("classWeight", mapping_expr[F.col("label")])

print("Class Weights Calculated:", class_weights)
>> Class Weights: {0.0: 0.3913544547816612, 1.0: 2.4997657766178145, 2.0: 22.354278874925196}
Class Weights Calculated: {0.0: 0.3913544547816612, 1.0: 2.4997657766178145, 2.0: 22.354278874925196}

Model Configuration & Tuning

Code
# Define Input Features (Observables only!)
feature_cols = ["bp_rp", "phot_g_mean_mag", "total_motion", "parallax", "log_parallax"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")


# Initialize Random Forest
rf = RandomForestClassifier(
    labelCol="label", 
    featuresCol="features", 
    weightCol="classWeight",
    seed=42,
    subsamplingRate=0.7, 
    featureSubsetStrategy="sqrt"
)

# Build Pipeline
pipeline = Pipeline(stages=[assembler, rf])

# Parameter Grid for Tuning
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [30, 50]) \
    .addGrid(rf.maxDepth, [8, 12]) \
    .build()

# Evaluator (Focus on Weighted F1 to balance all classes)
evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="f1")

# Cross Validator (3-Fold)
# It ensures the model is robust and not just lucky.
cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=3,
    parallelism=2  # Train 2 models in parallel if memory allows
)

# 5. Train/Test Split
train_data, test_data = df_weighted.randomSplit([0.8, 0.2], seed=42)
print(f"Training Count: {train_data.count()} | Test Count: {test_data.count()}")
Training Count: 657695 | Test Count: 164093
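
To give a rough sense of the training cost implied by this configuration, the sketch below counts the model fits that cross-validation performs. It mirrors the grid values and fold count from the cell above; the variable names are introduced here for illustration. It relies on the fact that Spark's CrossValidator refits the best parameter setting once on the full training set after the folds finish.

```python
from itertools import product

# Values mirrored from the parameter grid and CrossValidator setup
num_trees_grid = [30, 50]
max_depth_grid = [8, 12]
num_folds = 3

# Every combination of grid values is evaluated on every fold
param_combos = list(product(num_trees_grid, max_depth_grid))
cv_fits = len(param_combos) * num_folds

# Plus one final refit of the winning combination on all training rows
total_fits = cv_fits + 1

print(f"{len(param_combos)} combinations x {num_folds} folds = {cv_fits} fits")
print(f"Total including final refit: {total_fits}")
```

Adding a third value to either grid dimension would raise the count from 13 to 19 fits, which is why the grid is kept small.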

Training & Saving

Code
# Define the path
model_path = "stellar_classifier_rf_v1"

if not os.path.exists(model_path):
    print(">> Model not found. Training a new model...")
    print(">> Starting Cross-Validation Training (This may take 5-10 mins)...")
    
    # 1. Train
    cv_model = cv.fit(train_data)
    
    # 2. Extract the Best Model (The winner)
    best_model = cv_model.bestModel
    
    # 3. Save ONLY the Best Model (Lighter and standard practice)
    best_model.write().overwrite().save(model_path)
    print(f"Model saved to {model_path}")

else:
    print(">> Model found. Loading saved model...")
    
    # 4. Load as PipelineModel 
    best_model = PipelineModel.load(model_path)
    print(">> Model loaded.")

# --- Access the RF Stage for Parameters ---
# The Random Forest is the last stage in the pipeline (index -1)
best_rf_model = best_model.stages[-1]

print(f"\n>> Best Model Parameters:")
print(f"   Num Trees: {best_rf_model.getNumTrees}")
print(f"   Max Depth: {best_rf_model.getOrDefault('maxDepth')}")

# --- Evaluation ---
predictions = best_model.transform(test_data)

acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")
f1_eval = MulticlassClassificationEvaluator(metricName="f1")
prec_eval = MulticlassClassificationEvaluator(metricName="weightedPrecision")

print("\n=== FINAL MODEL EVALUATION ===")
print(f"Accuracy: {acc_eval.evaluate(predictions):.2%}")
print(f"F1 Score: {f1_eval.evaluate(predictions):.2%}")
print(f"Precision: {prec_eval.evaluate(predictions):.2%}")
>> Model found. Loading saved model...
>> Model loaded.

>> Best Model Parameters:
   Num Trees: 30
   Max Depth: 12

=== FINAL MODEL EVALUATION ===
Accuracy: 98.48%
F1 Score: 98.51%
Precision: 98.61%

Confusion Matrix

Code
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

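# 1. Collect Predictions (a single Spark job keeps labels and predictions row-aligned)
print(">> Collecting predictions to driver for visualization...")
pred_pd = predictions.select("label", "prediction").toPandas()
y_true = pred_pd["label"]
y_pred = pred_pd["prediction"]

# 2. Calculate Raw and Normalized Matrices
# Fix the class order explicitly so rows and columns match the tick labels
cm = confusion_matrix(y_true, y_pred, labels=[0.0, 1.0, 2.0])
labels = ["Main Seq", "Red Giant", "White Dwarf"]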

# Normalize row-wise: Divides count by the total stars in that TRUE class
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# 3. Plot the Heatmap
sns.heatmap(
    cm_normalized, 
    annot=True, 
    fmt='.1%',       
    cmap='Greens',   
    xticklabels=labels,
    yticklabels=labels,
    cbar_kws={'label': 'Recall (True Positive Rate)'},
    vmin=0, vmax=1   # Ensures color scale is fixed from 0% to 100%
)

plt.ylabel('Actual Star Type', fontsize=14)
plt.xlabel('Predicted Star Type', fontsize=14)
plt.tight_layout()
plt.show()
>> Collecting predictions to driver for visualization...
Figure 1: Normalized Confusion Matrix (%)

Feature Importance

Code
importances = best_rf_model.featureImportances
print("\n=== Feature Importance ===")
feat_imp_list = sorted(zip(feature_cols, importances), key=lambda x: x[1], reverse=True)
for feat, score in feat_imp_list:
    print(f"{feat}: {score:.4f}")

=== Feature Importance ===
log_parallax: 0.4239
parallax: 0.3025
phot_g_mean_mag: 0.2037
bp_rp: 0.0660
total_motion: 0.0038

Visualizing Predictions

Code
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# 1. Setup Data (10% sample; the seed makes the plotted subset reproducible)
plot_data = predictions.select("bp_rp", "abs_mag", "prediction").sample(False, 0.1, seed=42).toPandas()

# 2. Map Predictions to Names
label_map = {0.0: "Main Sequence", 1.0: "Red Giant", 2.0: "White Dwarf"}
plot_data['Star Type'] = plot_data['prediction'].map(label_map)

# 3. Define "Academic" Palette (High Contrast for White Background)
academic_palette = {
    
    "White Dwarf": "#9b59b6",    
    "Main Sequence": "#1abc9c",  
    "Red Giant": "#d35400"
}

# 4. Create Plot with "Seaborn Whitegrid" Style
plt.style.use('seaborn-v0_8-whitegrid')
ax = plt.gca()

order = ["Main Sequence", "Red Giant", "White Dwarf"]

for star_type in order:
    subset = plot_data[plot_data['Star Type'] == star_type]
    
    ax.scatter(
        subset['bp_rp'], 
        subset['abs_mag'], 
        c=academic_palette[star_type], 
        s=5,           
        alpha=0.4,     
        edgecolor='none',
        label=star_type
    )
# 5. Professional Aesthetics
ax.invert_yaxis()  # Standard Astronomy Convention
ax.set_title("Stellar Populations in Gaia DR3 (Predicted)", fontsize=16, weight='bold', pad=15)
ax.set_xlabel("Color Index ($G_{BP} - G_{RP}$) [mag]", fontsize=14)
ax.set_ylabel("Absolute Magnitude ($M_G$) [mag]", fontsize=14)

# Tick Customization 
ax.tick_params(axis='both', which='major', labelsize=12)

# Legend (Boxed, Top Right, like the paper)
legend = ax.legend(
    title='Classification', 
    fontsize=11, 
    title_fontsize=12, 
    loc='upper right', 
    frameon=True, 
    fancybox=False, # Square corners like the paper
    edgecolor='black',
    framealpha=1
)

plt.tight_layout()
plt.show()
Figure 2: Stellar Populations in Gaia DR3

3.2.4 Phase 2: Logistic Regression

Model setup

Code
# 1. Feature Scaling (Crucial for Logistic Regression)
# This ensures 'total_motion' (large values) doesn't drown out 'bp_rp' (small values).
scaler = StandardScaler(
    inputCol="features", 
    outputCol="scaled_features",
    withStd=True, 
    withMean=True
)

# 2. Define the Estimator
# - family="multinomial": Explicitly tells Spark to handle 3 classes.
# - weightCol="classWeight": Uses the same weights as RF to handle the class imbalance.
lr = LogisticRegression(
    labelCol="label", 
    featuresCol="scaled_features", 
    weightCol="classWeight", 
    family="multinomial",
    maxIter=100
)

Model Configuration & Tuning

Code
# 3. Build the Pipeline
# Pipeline flow: Raw Data -> Vector -> Scaled Vector -> Logistic Regression
pipeline_lr = Pipeline(stages=[assembler, scaler, lr])

# 4. Create Parameter Grid (Hyperparameter Tuning)
# - regParam: Controls regularization strength (prevents overfitting).
# - elasticNetParam: Mixes the penalties; 0.0 is pure L2, 1.0 is pure L1, 0.5 is an even blend.
paramGrid_lr = ParamGridBuilder() \
    .addGrid(lr.regParam, [0.01, 0.1]) \
    .addGrid(lr.elasticNetParam, [0.0, 0.5]) \
    .build()

# 5. Cross-Validation Setup
# We use the same 3-fold strategy as the Random Forest for consistency.
cv_lr = CrossValidator(
    estimator=pipeline_lr,
    estimatorParamMaps=paramGrid_lr,
    evaluator=evaluator, # Reusing the F1 evaluator from previous model
    numFolds=3,
    parallelism=2
)

Model Training & Evaluation

Code
# 6. Train the Model
print(">> Training Advanced Logistic Regression (with Scaling & CV)...")
cv_model_lr = cv_lr.fit(train_data)
best_model_lr = cv_model_lr.bestModel

print(">> Training Complete.")

# --- Save the Best LR Model ---
model_path_lr = "stellar_classifier_lr_v1"
best_model_lr.write().overwrite().save(model_path_lr)
print(f">> Model saved to {model_path_lr}")

# --- Metrics Evaluation ---
predictions_lr = best_model_lr.transform(test_data)

acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")
f1_eval = MulticlassClassificationEvaluator(metricName="f1")

print("\n=== LOGISTIC REGRESSION RESULTS ===")
print(f"Accuracy: {acc_eval.evaluate(predictions_lr):.2%}")
print(f"Weighted F1 Score: {f1_eval.evaluate(predictions_lr):.2%}")

# --- Advanced Analysis: Extract Coefficients ---
best_lr_stage = best_model_lr.stages[-1] # Extract the LR stage from pipeline

print("\n>> Model Coefficients (Linear Weights):")
# Coefficients is a matrix: 3 classes x 5 features
coeff_matrix = best_lr_stage.coefficientMatrix
# Intercepts: 3 values (one per class)
intercepts = best_lr_stage.interceptVector

# Print coefficients for the White Dwarf class (Label 2.0)
# This helps explain what makes a star a "White Dwarf" according to the math.
wd_coeffs = coeff_matrix.toArray()[2] 
print(f"White Dwarf Coefficients (vs Features): {wd_coeffs}")
>> Training Advanced Logistic Regression (with Scaling & CV)...
>> Training Complete.
>> Model saved to stellar_classifier_lr_v1

=== LOGISTIC REGRESSION RESULTS ===
Accuracy: 92.02%
Weighted F1 Score: 92.92%

>> Model Coefficients (Linear Weights):
White Dwarf Coefficients (vs Features): [0.03140263 1.14112309 0.00354785 0.58773451 1.17785557]
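
The raw coefficient vector is easier to read when paired with feature names. The sketch below does this for the values printed above, assuming the coefficients follow the order of feature_cols passed to the VectorAssembler. Because the features were standardized, the magnitudes are roughly comparable, although multinomial coefficients are defined relative to the other classes, so this ranking should be read as indicative only.

```python
# Feature order as passed to the VectorAssembler (assumed to match the coefficients)
feature_cols = ["bp_rp", "phot_g_mean_mag", "total_motion", "parallax", "log_parallax"]

# White Dwarf coefficients copied from the printed output
wd_coeffs = [0.03140263, 1.14112309, 0.00354785, 0.58773451, 1.17785557]

# Rank features by the absolute size of their coefficient
ranked = sorted(zip(feature_cols, wd_coeffs), key=lambda fc: abs(fc[1]), reverse=True)

for name, coeff in ranked:
    print(f"{name:16s} {coeff:+.4f}")
```

On this reading, log_parallax and phot_g_mean_mag carry the largest weights for the White Dwarf class, while total_motion contributes almost nothing, which is broadly consistent with the Random Forest feature importances.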

Visualizing Predictions

Code
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql import functions as F

# --- 1. Data Preparation ---
# print(">> Collecting data for Visualization...")
df_viz = predictions_lr.select("bp_rp", "abs_mag", "label", "prediction") \
    .sample(False, 0.2, seed=42) \
    .toPandas()

# --- Logic for Plot A: Error Analysis ---
df_viz["is_correct"] = df_viz["label"] == df_viz["prediction"]
df_viz["Prediction Status"] = df_viz["is_correct"].map({True: "Correct", False: "Misclassified"})

# --- Logic for Plot B: Prediction Classes ---
label_map = {0.0: "Main Sequence", 1.0: "Red Giant", 2.0: "White Dwarf"}
df_viz["Predicted Star Type"] = df_viz["prediction"].map(label_map)

# Set Global Style
sns.set_style("whitegrid")
custom_red = "#e63946"
custom_med_blue = "#457b9d"
custom_dark_blue = "#1d3557"


# ==========================================
#   PLOT 1: The Error Map
# ==========================================
fig1, ax1 = plt.subplots(figsize=(8, 6))

error_palette = {"Correct": custom_dark_blue, "Misclassified": custom_red}

sns.scatterplot(
    data=df_viz,
    x="bp_rp",
    y="abs_mag",
    hue="Prediction Status",
    palette=error_palette,
    s=15, # Increased size slightly since the plot is bigger now
    alpha=0.5,
    edgecolor=None,
    ax=ax1
)

ax1.set_title("A. Error Analysis (Where did it fail?)", fontsize=14, weight='bold')
ax1.legend(loc="upper right", frameon=True, edgecolor='black', title="Accuracy", fontsize=10, title_fontsize=11)

# Formatting
ax1.invert_yaxis()
ax1.set_xlabel("Color Index ($G_{BP} - G_{RP}$)", fontsize=12)
ax1.set_ylabel("Absolute Magnitude ($M_G$)", fontsize=12)
ax1.grid(True, linestyle='-', alpha=0.3)
ax1.tick_params(axis='both', which='major', labelsize=10)

plt.tight_layout()
plt.show()


# ==========================================
#   PLOT 2: The Prediction Map
# ==========================================
fig2, ax2 = plt.subplots(figsize=(8, 6))

model_palette = {
    "Red Giant": custom_red, 
    "Main Sequence": custom_med_blue,
    "White Dwarf": custom_dark_blue  
}

sns.scatterplot(
    data=df_viz,
    x="bp_rp",
    y="abs_mag",
    hue="Predicted Star Type",
    palette=model_palette,
    hue_order=["Main Sequence", "Red Giant", "White Dwarf"],
    s=15,
    alpha=0.5,
    edgecolor=None,
    ax=ax2
)

ax2.set_title("B. Model Predictions (Linear Boundaries)", fontsize=14, weight='bold')
ax2.legend(loc="upper right", frameon=True, edgecolor='black', title="Predicted Class", fontsize=10, title_fontsize=11)

# Formatting
ax2.invert_yaxis()
ax2.set_xlabel("Color Index ($G_{BP} - G_{RP}$)", fontsize=12)
ax2.set_ylabel("Absolute Magnitude ($M_G$)", fontsize=12)
ax2.grid(True, linestyle='-', alpha=0.3)
ax2.tick_params(axis='both', which='major', labelsize=10)

plt.tight_layout()
plt.show()
(a) Error Analysis
(b) Model Predictions
Figure 3: Side-by-Side Visualization of Predictions vs Truth

3.2.5 Phase 3: Comparing Models

Code
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# --- 1. Define Evaluation Function ---
def get_metrics(predictions, model_name):
    # Initialize evaluators
    acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")
    f1_eval = MulticlassClassificationEvaluator(metricName="f1")
    prec_eval = MulticlassClassificationEvaluator(metricName="weightedPrecision")
    rec_eval = MulticlassClassificationEvaluator(metricName="weightedRecall")
    
    # Calculate scores
    return {
        "Model": model_name,
        "Accuracy": acc_eval.evaluate(predictions),
        "F1 Score": f1_eval.evaluate(predictions),
        "Precision": prec_eval.evaluate(predictions),
        "Recall": rec_eval.evaluate(predictions)
    }

# --- 2. Collect Data ---
print(">> Calculating metrics for comparison...")
# Assuming 'predictions' is from Random Forest and 'predictions_lr' is from Logistic Regression
metrics_rf = get_metrics(predictions, "Random Forest")
metrics_lr = get_metrics(predictions_lr, "Logistic Regression")

# Create DataFrame
df_metrics = pd.DataFrame([metrics_rf, metrics_lr])

# Melt for plotting (Long format)
df_melted = df_metrics.melt(id_vars="Model", var_name="Metric", value_name="Score")
>> Calculating metrics for comparison...
Code
# --- 3. Create Comparison Visual ---
plt.figure(figsize=(10, 6))
sns.set_style("whitegrid")

# Custom Palette: Dark Blue vs Red
custom_palette = ["#1d3557", "#e63946"] 

ax = sns.barplot(
    data=df_melted,
    x="Metric",
    y="Score",
    hue="Model",
    palette=custom_palette,
    edgecolor="black",
    linewidth=0.8
)

# --- 4. Aesthetics ---
plt.title("Performance Showdown: Random Forest vs. Logistic Regression", fontsize=14, weight='bold', pad=15)
plt.ylabel("Score (0.0 - 1.0)", fontsize=12)
plt.xlabel("Evaluation Metric", fontsize=12)
plt.ylim(0.8, 1.0) # Zoom in to show differences clearly (Adjust if scores are lower)
plt.legend(title="Algorithm", loc="lower right", frameon=True, edgecolor='black')

# Add values on top of bars
for container in ax.containers:
    ax.bar_label(container, fmt='%.3f', padding=3, fontsize=10)

plt.tight_layout()
plt.show()

# Print Table for Report
print("\n=== FINAL COMPARISON TABLE ===")
print(df_metrics.round(4))
Figure 4: Model Comparison

=== FINAL COMPARISON TABLE ===
                 Model  Accuracy  F1 Score  Precision  Recall
0        Random Forest    0.9848    0.9851     0.9861  0.9848
1  Logistic Regression    0.9202    0.9292     0.9519  0.9202

Comparison Analysis

Conclusion: Which Model Performed Better?

The Random Forest Classifier is the superior model for this prediction goal.

While Logistic Regression provides a useful baseline and helps explain linear relationships, the Random Forest is better suited to the specific scientific nature of the data.

  • Performance: As shown in the bar chart, the Random Forest model achieves higher scores on all four metrics, including Accuracy (0.9848 vs. 0.9202) and F1 (0.9851 vs. 0.9292).
  • Reasoning: The physical boundaries between star types (specifically Main Sequence vs. Red Giants) are curved, not straight. Random Forest can perform well on these complex, non-linear boundaries.
  • Weakness of Logistic Regression: The Logistic Regression model attempts to draw a straight line through the data. As seen in the “Error Map” previously, this linear boundary cuts through the curved Red Giant branch, leading to higher misclassification rates in that specific region.

3.3 Jayrup

The objective of this task is to predict stellar parallax using features derived from the Gaia dataset. Parallax is a continuous variable and therefore the task is formulated as a regression problem. Predicting parallax allows indirect estimation of stellar distance, which is a fundamental problem in astrophysics.

Understanding the data

The following features were selected based on their relevance and exploratory data analysis:

  • Total Proper Motion (\(\mu\)): Computed as

\[ \mu = \sqrt{\text{pmra}^2 + \text{pmdec}^2} \]

Proper motion provides kinematic information related to stellar distance.

  • Photometric G-band Magnitude (phot_g_mean_mag): Represents apparent brightness, which is distance dependent.

  • Colour Index (bp_rp): Used as a proxy for stellar temperature and population type.

  • Effective Temperature (teff_gspphot): Provides additional physical context to distinguish between stellar populations.
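
The total proper motion formula in the list above can be checked by hand. The sketch below is a plain-Python illustration, not part of the Spark pipeline; the helper name `total_proper_motion` is hypothetical, and the input values are the pmra and pmdec of the fastest-moving source reported in the outlier table later in this section.

```python
import math

def total_proper_motion(pmra: float, pmdec: float) -> float:
    """Combine the two proper-motion components (mas/yr) into one magnitude."""
    # math.hypot computes sqrt(pmra**2 + pmdec**2)
    return math.hypot(pmra, pmdec)

# pmra and pmdec of the top row of the outlier table (mas/yr)
mu = total_proper_motion(-801.5509783684709, 10362.394206546573)
print(f"Total proper motion: {mu:.2f} mas/yr")
```

The result should agree with the total_motion value Spark reports for the same source, which is a quick sanity check that both components enter the formula.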

Code
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("Parallax_Prediction") \
    .config("spark.driver.memory", "4g") \
    .getOrCreate()

df = spark.read.parquet("../data/gaia_100pc.parquet")

# Show basic stats
df.select('pmra', 'pmdec', 'parallax').describe().show()

# Count total rows
print(f"Total rows: {df.count()}")
+-------+------------------+-------------------+------------------+
|summary|              pmra|              pmdec|          parallax|
+-------+------------------+-------------------+------------------+
|  count|            541958|             541958|            541958|
|   mean|-3.085397329690918|-23.233364391692742|14.283010896409225|
| stddev|  94.7899342646246|   88.4285353154643|7.0432588218793475|
|    min|-4406.469178827325|-5817.8001940492695|10.000005410606198|
|    max| 6765.995136250774| 10362.394206546573| 768.0665391873573|
+-------+------------------+-------------------+------------------+

Total rows: 541958
Code
from pyspark.sql.functions import sqrt, col

# Calculate total proper motion magnitude
df = df.withColumn("total_motion", sqrt(col("pmra")**2 + col("pmdec")**2))

# Check the results
df.select('pmra', 'pmdec', 'total_motion').describe().show()
+-------+------------------+-------------------+-------------------+
|summary|              pmra|              pmdec|       total_motion|
+-------+------------------+-------------------+-------------------+
|  count|            541958|             541958|             541958|
|   mean|-3.085397329690918|-23.233364391692742|  72.42738578220973|
| stddev|  94.7899342646246|   88.4285353154643|   110.037773135682|
|    min|-4406.469178827325|-5817.8001940492695|0.00949718093624809|
|    max| 6765.995136250774| 10362.394206546573| 10393.348722273944|
+-------+------------------+-------------------+-------------------+

Visualizing the data

Code
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pyspark.sql.functions import col, sqrt

# Recompute total motion so this cell can run on its own
df = df.withColumn("total_motion", sqrt(col("pmra")**2 + col("pmdec")**2))

# Convert to Pandas for plotting
pdf = df.select("total_motion").toPandas()

# Plot 1: histogram on a linear axis
# plt.figure(figsize=(10, 6))
sns.histplot(pdf["total_motion"], bins=100, kde=True)
plt.show()

# Plot 2: histogram on a logarithmic axis
# Create 100 logarithmically spaced bin edges (99 bins) from the min to the max of the data.
# This relies on total_motion being strictly positive, which holds for this dataset.
bins = np.logspace(np.log10(pdf["total_motion"].min()),
                   np.log10(pdf["total_motion"].max()),
                   100)

# plt.figure(figsize=(10, 6))
plt.hist(pdf["total_motion"], bins=bins)
plt.xscale('log')
plt.xlabel("Total Motion (Log Scale)")
plt.ylabel("Frequency")
plt.show()

Figure: Distribution of Total Motion (linear scale)

Figure: Distribution of Total Motion (Log Scale)

We can gain the following insights from the plots:

  • The linear plot (Figure 1) is heavily right-skewed, with most of the data compressed into the first few bins. A model trained on the raw values would over-weight the few extreme outliers (high-motion stars) while treating the vast majority of the dataset as having near-zero variance.

  • The log transformation (Figure 2) reveals that the underlying data distribution is actually bimodal (two distinct peaks). The linear plot obscured this physical distinction between the background stars (first peak) and the nearby high-proper-motion stars (second peak).

  • Linear Regression in particular assumes roughly constant error variance across the range of values. The raw proper motion spans several orders of magnitude (from about 0.01 to over 10,000 mas/yr); applying a log transform stabilizes the variance, making the error metrics (RMSE) meaningful across the entire dataset rather than just for the fastest stars.
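
To make the effect of the log transform concrete, the sketch below applies log10(x + 1), the same transform used in the Processing step, to three representative values. The sample values are rounded from the minimum, mean and maximum of total_motion in the summary table above; the variable names are illustrative only.

```python
import math

# Rounded min, mean and max of total_motion from the summary table (mas/yr)
samples = [0.0095, 72.43, 10393.35]

# log10(x + 1): the +1 keeps the result non-negative for values below 1
logged = [math.log10(x + 1) for x in samples]

raw_ratio = max(samples) / min(samples)   # how many times larger the max is than the min
log_span = max(logged) - min(logged)      # width of the same range after the transform

for x, y in zip(samples, logged):
    print(f"raw = {x:>10.4f}  ->  log10(raw + 1) = {y:.4f}")
print(f"Raw max/min ratio: {raw_ratio:,.0f}")
print(f"Span in log space: {log_span:.2f}")
```

A range in which the maximum is more than a million times the minimum collapses to a span of about four log units, so no single star dominates the squared-error loss.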

Code
# Filter for extreme outliers based on the plot's X-axis
outliers_df = df.filter(col("total_motion") > 1000)

print(f"Found {outliers_df.count()} extreme outliers.")
outliers_df.sort(col("total_motion").desc()).show(10)
Found 628 extreme outliers.
+-------------------+------------------+------------------+------------------+--------------+------------------+-------------------+---------------+---------+------------+------------------+
|          source_id|                ra|               dec|          parallax|parallax_error|              pmra|              pmdec|phot_g_mean_mag|    bp_rp|teff_gspphot|      total_motion|
+-------------------+------------------+------------------+------------------+--------------+------------------+-------------------+---------------+---------+------------+------------------+
|4472832130942575872|269.44850252543836| 4.739420051112412|  546.975939730948|   0.040116355|-801.5509783684709| 10362.394206546573|      8.1939745|2.8336968|   3099.6335|10393.348722273944|
|4810594479418041856|  77.9599373502188| -45.0438126993602|254.19859326384577|   0.016842743| 6491.223339061598| -5708.614150045243|       8.063552|2.0266457|   3451.8704| 8644.319287929779|
|4034171629042489088|178.26735320817272| 37.69282694689086|109.02963997046682|   0.019686269| 4002.654640989075|-5817.8001940492695|      6.1985016|1.0016494|   5043.2183| 7061.730897797727|
|6553614253923452800| 346.5039166796005| -35.8471642082214| 304.1353692001036|    0.01999573| 6765.995136250774| 1330.2852747179845|      6.5220323|2.0982852|   3376.0845|6895.5310959998305|
|2306965202564744064|1.3832841523481234|-37.36774402806293|230.09703402875448|   0.036182754| 5633.438087895326|-2334.7212726520424|      7.6824937| 2.186049|   3355.3533| 6098.077411047167|
|1872046609345556480| 316.7484792940004| 38.76386244649797|285.99494829578117|    0.05989728|4164.2086922846665|  3249.613883848584|       4.766713|1.4625897|   4353.7437| 5282.104166617755|
|3098328182579892096|122.99468256110134| 8.750401522062495|147.72184850183513|   0.094956644| 1069.811738087307| -5094.220103378359|      11.397289| 2.992114|        NULL| 5205.341066310027|
|1872046574983497216|  316.753662752556| 38.75607277205679| 286.0053518616485|   0.028940246| 4105.976428209489| 3155.9416398273515|      5.4506445|1.7153406|   3889.6328| 5178.707373757288|
|  35227046884571776| 43.26964247679057| 16.86437381897744|260.98844068047276|    0.09342672|3429.0828268077694|  -3805.54112273733|      12.263103|4.5285025|        NULL| 5122.572817437822|
| 762815470562110464|165.83095967577933|35.948653032660104|392.75294543876464|    0.03206665|-580.0570872139048| -4776.588719443488|       6.551172|2.2156086|    3511.045| 4811.680165923527|
+-------------------+------------------+------------------+------------------+--------------+------------------+-------------------+---------------+---------+------------+------------------+
only showing top 10 rows

This confirms we are looking at real, high-quality data:

  • Barnard’s Star: The first outlier in our table (Source ID 4472832130942575872, total_motion ~10,393 mas/yr) is almost certainly Barnard’s Star, which has the highest proper motion of any known star.
  • The 100 pc limit: The minimum parallax is ~10.0 mas, which corresponds to a distance of 100 parsecs (\(d = 1000 / \pi\), with \(d\) in parsecs and the parallax \(\pi\) in milliarcseconds), matching our filename gaia_100pc.
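
The parallax-to-distance relation used above can be sketched as a small helper. This is an illustrative plain-Python function (the name `parallax_to_distance_pc` is hypothetical), and it uses the simple inverse relation, which is only a reasonable approximation when the parallax error is small, as it is for the sources in the table.

```python
def parallax_to_distance_pc(parallax_mas: float) -> float:
    """Convert a parallax in milliarcseconds to a distance in parsecs (d = 1000 / parallax)."""
    if parallax_mas <= 0:
        raise ValueError("parallax must be positive for the inverse relation to apply")
    return 1000.0 / parallax_mas

# Catalogue cut-off: a parallax of 10 mas corresponds to the 100 pc limit
d_limit = parallax_to_distance_pc(10.0)

# Parallax of the top outlier (the Barnard's Star candidate) from the table above
d_top = parallax_to_distance_pc(546.975939730948)

print(f"Distance at 10 mas: {d_limit:.1f} pc")
print(f"Distance of top outlier: {d_top:.3f} pc")
```

The second value comes out just under 2 pc, consistent with the top outlier being one of the nearest stars to the Sun.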

Processing

To reduce skewness and improve model stability, total proper motion and parallax are log-transformed (log10(x + 1)) before training.

Code
import os

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression, RandomForestRegressor
from pyspark.ml.regression import LinearRegressionModel, RandomForestRegressionModel
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import col, log10


# Add log columns to handle the huge range of motion
df_ml = df.withColumn("log_motion", log10(col("total_motion") + 1))
df_ml = df_ml.withColumn("log_parallax", log10(col("parallax") + 1))

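# Drop rows with missing inputs. Note: teff_gspphot is filtered for nulls here
# but is not included in the feature vector assembled below.
df_ml = df_ml.na.drop(subset=["bp_rp", "phot_g_mean_mag", "total_motion","teff_gspphot"])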

# Prepare Features (Using the bimodal motion and magnitude)
assembler = VectorAssembler(
    inputCols=["log_motion", "phot_g_mean_mag","bp_rp"], 
    outputCol="features", 
    handleInvalid="skip"
)
data = assembler.transform(df_ml)

# Split into Training (80%) and Testing (20%)
train_df, test_df = data.randomSplit([0.8, 0.2], seed=42)

lr_path = "models/gaia_linear_regression"
rf_path = "models/gaia_random_forest"

if os.path.exists(lr_path):
  # Load Model A: Linear Regression
  lr_model = LinearRegressionModel.load(lr_path)
else:
  # Train Model A: Linear Regression
  lr = LinearRegression(featuresCol="features", labelCol="log_parallax")
  lr_model = lr.fit(train_df)

  lr_model.write().overwrite().save(lr_path)
  print(f"Model saved to {lr_path}")

if os.path.exists(rf_path):
  # Load Model B: Random Forest
  rf_model = RandomForestRegressionModel.load(rf_path)
else:
  # Train Model B: Random Forest
  rf = RandomForestRegressor(
      featuresCol="features",
      labelCol="log_parallax",
      numTrees=100,
      maxDepth=12,
      # minInstancesPerNode=20,
      # featureSubsetStrategy="sqrt",
      seed=42
  )
  rf_model = rf.fit(train_df)

  rf_model.write().overwrite().save(rf_path)
  print(f"Model saved to {rf_path}")

# 5. Get Predictions
lr_pred = lr_model.transform(test_df)
rf_pred = rf_model.transform(test_df)

# 6. Evaluate
eval_r2 = RegressionEvaluator(labelCol="log_parallax", predictionCol="prediction", metricName="r2")
eval_rmse = RegressionEvaluator(labelCol="log_parallax", predictionCol="prediction", metricName="rmse")

print("Linear Regression R2:", eval_r2.evaluate(lr_pred))
print("Random Forest R2:   ", eval_r2.evaluate(rf_pred))
print("-" * 30)
print("Linear Regression RMSE:", eval_rmse.evaluate(lr_pred))
print("Random Forest RMSE:   ", eval_rmse.evaluate(rf_pred))
Linear Regression R2: 0.6381499150350736
Random Forest R2:    0.7552582025878004
------------------------------
Linear Regression RMSE: 0.0882892047246452
Random Forest RMSE:    0.07261015068300783
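
Because both models are trained on log10(parallax + 1), the RMSE values above are in log units rather than milliarcseconds. The sketch below is one hedged way to read them: an error of r in log10 space corresponds to a multiplicative factor of 10^r on (parallax + 1). The helper names `log_to_parallax` and `rmse_as_factor` are hypothetical and are not part of the pipeline.

```python
import math

def log_to_parallax(log_value: float) -> float:
    """Invert log10(parallax + 1) to recover the parallax in mas."""
    return 10 ** log_value - 1

def rmse_as_factor(rmse_log: float) -> float:
    """Typical multiplicative error on (parallax + 1) implied by a log10-space RMSE."""
    return 10 ** rmse_log

# RMSE values copied from the evaluation output above
lr_factor = rmse_as_factor(0.0882892047246452)
rf_factor = rmse_as_factor(0.07261015068300783)

print(f"Linear Regression: typical error factor x{lr_factor:.3f}")
print(f"Random Forest:     typical error factor x{rf_factor:.3f}")

# Round trip: a prediction of log10(11) maps back to a parallax of 10 mas
print(f"Recovered parallax: {log_to_parallax(math.log10(11)):.1f} mas")
```

Under this reading the factors are roughly 1.23 for Linear Regression and 1.18 for Random Forest, so the Random Forest's typical error on (parallax + 1) is about 18% versus about 23%.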
Code
import matplotlib.pyplot as plt
import seaborn as sns

# Convert Spark DataFrame to Pandas for plotting
pdf = rf_pred.select("log_parallax", "prediction", "bp_rp").sample(fraction=0.05, seed=42).toPandas()

# Scatter plot: predicted vs true log_parallax
plt.figure(figsize=(8, 6))
scatter = plt.scatter(
    pdf["log_parallax"],
    pdf["prediction"],
    c=pdf["bp_rp"],        # color by bp_rp
    cmap="viridis",
    s=10,
    alpha=0.6
)
plt.plot([pdf["log_parallax"].min(), pdf["log_parallax"].max()],
         [pdf["log_parallax"].min(), pdf["log_parallax"].max()],
         color="red", linestyle="--", label="y = x")
plt.xlabel("True log(Parallax)")
plt.ylabel("Predicted log(Parallax)")
plt.colorbar(scatter, label="bp_rp")
plt.legend()
plt.show()

pdf_lr = lr_pred.select("log_parallax", "prediction", "bp_rp").sample(fraction=0.05, seed=42).toPandas()

plt.figure(figsize=(8,6))
sns.scatterplot(
    x=pdf_lr["log_parallax"], y=pdf_lr["prediction"],
    hue=pdf_lr["bp_rp"], palette="coolwarm", alpha=0.6, s=10
)
plt.plot([pdf_lr["log_parallax"].min(), pdf_lr["log_parallax"].max()],
         [pdf_lr["log_parallax"].min(), pdf_lr["log_parallax"].max()],
         color="black", linestyle="--", label="y = x")
plt.xlabel("True log(Parallax)")
plt.ylabel("Predicted log(Parallax)")
plt.legend()
plt.show()

Figure: Random Forest Predictions vs True Values

Figure: Linear Regression Predictions vs True Values

Random Forest Predictions vs Linear Regression Predictions